whisper 模型源码解读

news/2024/9/25 17:14:30/

在这里插入图片描述

whisper_1">whisper官方源码

whisper 模型官方代码:https://github.com/openai/whisper/blob/main/whisper/model.py ;注释如下

import base64
import gzip
from dataclasses import dataclass
from typing import Dict, Iterable, Optionalimport numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor, nn# 从其他模块导入必要的函数
from .decoding import decode as decode_function
from .decoding import detect_language as detect_language_function
from .transcribe import transcribe as transcribe_function@dataclass
class ModelDimensions:"""该类用于存储模型的各项参数"""n_mels: int  # Mel谱图的频带数量n_audio_ctx: int  # 音频上下文窗口大小n_audio_state: int  # 音频状态维度n_audio_head: int  # 音频注意力头数量n_audio_layer: int  # 音频层数量n_vocab: int  # 词汇表大小n_text_ctx: int  # 文本上下文窗口大小n_text_state: int  # 文本状态维度n_text_head: int  # 文本注意力头数量n_text_layer: int  # 文本层数量class LayerNorm(nn.LayerNorm):def forward(self, x: Tensor) -> Tensor:"""重写 forward 方法,确保输入张量的类型在归一化前后保持一致"""return super().forward(x.float()).type(x.dtype)class Linear(nn.Linear):def forward(self, x: Tensor) -> Tensor:"""重写 forward 方法,确保权重和偏置与输入张量的类型一致"""return F.linear(x,self.weight.to(x.dtype),None if self.bias is None else self.bias.to(x.dtype),)class Conv1d(nn.Conv1d):def _conv_forward(self, x: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor:"""重写 _conv_forward 方法,确保卷积操作中的权重和偏置与输入张量的类型一致"""return super()._conv_forward(x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype))def sinusoids(length, channels, max_timescale=10000):"""生成用于位置嵌入的正弦曲线"""assert channels % 2 == 0log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)class MultiHeadAttention(nn.Module):def __init__(self, n_state: int, n_head: int):"""初始化多头注意力层"""super().__init__()self.n_head = n_headself.query = Linear(n_state, n_state)self.key = Linear(n_state, n_state, bias=False)self.value = Linear(n_state, n_state)self.out = Linear(n_state, n_state)def forward(self,x: Tensor,xa: Optional[Tensor] = None,mask: Optional[Tensor] = None,kv_cache: Optional[dict] = None,):"""多头注意力的前向传播"""q = self.query(x)if kv_cache is None or xa is None or self.key not in kv_cache:# 如果没有缓存键和值,则正常计算k = self.key(x if xa is None else xa)v = self.value(x if xa is None else xa)else:# 如果有缓存,则使用缓存的键和值k = kv_cache[self.key]v = kv_cache[self.value]wv, qk = self.qkv_attention(q, k, v, mask)return self.out(wv), qkdef qkv_attention(self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None):"""计算 QKV 注意力"""n_batch, n_ctx, n_state = q.shapescale = (n_state // self.n_head) ** -0.25q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scalek = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scalev = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)qk = q @ kif mask is not None:qk = qk + mask[:n_ctx, :n_ctx]qk = qk.float()w = F.softmax(qk, dim=-1).to(q.dtype)return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()class ResidualAttentionBlock(nn.Module):def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):"""初始化残差注意力块"""super().__init__()self.attn = MultiHeadAttention(n_state, n_head)self.attn_ln = LayerNorm(n_state)self.cross_attn = (MultiHeadAttention(n_state, n_head) if cross_attention else None)self.cross_attn_ln = LayerNorm(n_state) if cross_attention else Nonen_mlp = n_state * 4self.mlp = nn.Sequential(Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state))self.mlp_ln = LayerNorm(n_state)def forward(self,x: Tensor,xa: Optional[Tensor] = None,mask: Optional[Tensor] = None,kv_cache: Optional[dict] = None,):"""残差注意力块的前向传播"""x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0]if self.cross_attn:x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0]x = x + self.mlp(self.mlp_ln(x))return xclass AudioEncoder(nn.Module):def __init__(self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):"""初始化音频编码器"""super().__init__()self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList([ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)])self.ln_post = LayerNorm(n_state)def forward(self, x: Tensor):"""前向传播,处理音频输入x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)音频的Mel谱图"""x = F.gelu(self.conv1(x))x = F.gelu(self.conv2(x))x = x.permute(0, 2, 1)assert x.shape[1:] == self.positional_embedding.shape, "音频形状不正确"x = (x + self.positional_embedding).to(x.dtype)for block in self.blocks:x = block(x)x = self.ln_post(x)return xclass TextDecoder(nn.Module):def __init__(self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):"""初始化文本解码器"""super().__init__()self.token_embedding = nn.Embedding(n_vocab, n_state)self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList([ResidualAttentionBlock(n_state, n_head, cross_attention=True)for _ in range(n_layer)])self.ln = LayerNorm(n_state)mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)self.register_buffer("mask", mask, persistent=False)def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):"""前向传播,处理文本输入并结合音频特征x : torch.LongTensor, shape = (batch_size, <= n_ctx)文本的标记序列xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)编码后的音频特征"""offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0x = (self.token_embedding(x)+ self.positional_embedding[offset : offset + x.shape[-1]])x = x.to(xa.dtype)for block in self.blocks:x = block(x, xa, mask=self.mask, kv_cache=kv_cache)x = self.ln(x)logits = (x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)).float()return logitsclass Whisper(nn.Module):def __init__(self, dims: ModelDimensions):"""初始化 Whisper 模型"""super().__init__()self.dims = dimsself.encoder = AudioEncoder(self.dims.n_mels,self.dims.n_audio_ctx,self.dims.n_audio_state,self.dims.n_audio_head,self.dims.n_audio_layer,)self.decoder = TextDecoder(self.dims.n_vocab,self.dims.n_text_ctx,self.dims.n_text_state,self.dims.n_text_head,self.dims.n_text_layer,)# 默认情况下,使用解码器层的后一半进行时间对齐;# 若要使用特定的注意力头,可以使用 `set_alignment_heads()` 方法。all_heads = torch.zeros(self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool)all_heads[self.dims.n_text_layer // 2 :] = Trueself.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)def set_alignment_heads(self, dump: bytes):"""设置对齐的注意力头"""array = np.frombuffer(gzip.decompress(base64.b85decode(dump)), dtype=bool).copy()mask = torch.from_numpy(array).reshape(self.dims.n_text_layer, self.dims.n_text_head)self.register_buffer("alignment_heads", mask.to_sparse(), persistent=False)def embed_audio(self, mel: torch.Tensor):"""编码音频特征"""return self.encoder(mel)def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):"""获取预测的logits"""return self.decoder(tokens, audio_features)def forward(self, mel: torch.Tensor, tokens: torch.Tensor) -> Dict[str, torch.Tensor]:"""前向传播"""return self.decoder(tokens, self.encoder(mel))@propertydef device(self):"""获取模型所在的设备"""return next(self.parameters()).device@propertydef is_multilingual(self):"""判断模型是否支持多语言"""return self.dims.n_vocab >= 51865@propertydef num_languages(self):"""获取模型支持的语言数量"""return self.dims.n_vocab - 51765 - int(self.is_multilingual)def install_kv_cache_hooks(self, cache: Optional[dict] = None):"""为键和值的投影模块安装缓存钩子返回-------cache : Dict[nn.Module, torch.Tensor]映射键/值投影模块到其缓存的字典对象hooks : List[RemovableHandle]用于停止调用钩子的 PyTorch RemovableHandle 对象列表"""cache = {**cache} if cache is not None else {}hooks = []def save_to_cache(module, _, output):if module not in cache or output.shape[1] > self.dims.n_text_ctx:# 第一次标记或交叉注意时保存原始值cache[module] = outputelse:cache[module] = torch.cat([cache[module], output], dim=1).detach()return cache[module]def install_hooks(layer: nn.Module):if isinstance(layer, MultiHeadAttention):hooks.append(layer.key.register_forward_hook(save_to_cache))hooks.append(layer.value.register_forward_hook(save_to_cache))self.decoder.apply(install_hooks)return cache, hooksdetect_language = detect_language_function  # 语言检测函数transcribe = transcribe_function  # 转录函数decode = decode_function  # 解码函数

语音识别自回归解码过程分析和举例说明

分析

语音识别自回归解码过程通常涉及以下步骤:

  1. 音频预处理:首先将输入的音频信号转换为Mel谱图。这一步骤在实际应用中通常由音频前端处理模块完成。

  2. 音频编码:将预处理后的Mel谱图输入到音频编码器中,生成音频特征表示。这些特征表示将作为后续文本解码器的输入。

  3. 文本解码:文本解码器通过自回归方式生成文本序列。具体来说,文本解码器在每个时间步上根据前一步生成的文本标记以及音频特征生成下一个文本标记。

  4. 语言检测和转录:在生成的文本序列基础上,可以进行语言检测,确认文本所使用的语言。此外,转录过程将生成的文本序列转换为最终的文本输出。

具体步骤

以下代码展示了上述过程的具体实现:

import torch# 初始化模型参数
dims = ModelDimensions(n_mels=80,n_audio_ctx=1500,n_audio_state=512,n_audio_head=8,n_audio_layer=6,n_vocab=51865,n_text_ctx=448,n_text_state=512,n_text_head=8,n_text_layer=6,
)# 创建模型实例
model = Whisper(dims)# 假设我们有一个Mel谱图输入
mel_spectrogram = torch.randn(1, 80, 1500)  # (batch_size, n_mels, n_audio_ctx)# 编码音频特征
audio_features = model.embed_audio(mel_spectrogram)# 假设我们有一个初始的文本标记序列
initial_tokens = torch.tensor([[1, 2, 3]])  # (batch_size, seq_len)# 自回归解码过程
for _ in range(10):  # 假设生成长度为10的序列logits = model.logits(initial_tokens, audio_features)next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)initial_tokens = torch.cat([initial_tokens, next_token], dim=-1)# 最终生成的文本标记序列
final_tokens = initial_tokens# 打印生成的文本标记序列
print("Generated tokens:", final_tokens)

举例说明

假设我们有一段音频,其Mel谱图表示如下:

mel_spectrogram = torch.randn(1, 80, 1500)

我们希望通过自回归解码生成对应的文本表示。首先,我们将Mel谱图输入到音频编码器中,得到音频特征表示:

audio_features = model.embed_audio(mel_spectrogram)

然后,我们使用一个初始的文本标记序列(例如,序列开始标记)开始自回归解码过程:

initial_tokens = torch.tensor([[1]])  # 序列开始标记

在每个时间步,我们根据当前的文本标记序列和音频特征生成下一个文本标记:

logits = model.logits(initial_tokens, audio_features)
next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
initial_tokens = torch.cat([initial_tokens, next_token], dim=-1)

这个过程重复若干次(例如10次)直到生成完整的文本序列:

for _ in range(10):logits = model.logits(initial_tokens, audio_features)next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)initial_tokens = torch.cat([initial_tokens, next_token], dim=-1)

最终得到的文本标记序列为:

final_tokens = initial_tokens
print("Generated tokens:", final_tokens)

以上示例展示了从音频输入到文本输出的完整自回归解码过程。


http://www.ppmy.cn/news/1470344.html

相关文章

微信聊天记录导出为电脑文件实操教程(附代码)

写在前面 最近&#xff0c;微信中加的群有点多&#xff0c;信息根本看不过来。如果不看&#xff0c;怕遗漏了有价值的信息&#xff1b;如果一条条向上翻阅&#xff0c;实在是太麻烦。 有没有办法一键导出所有聊天记录&#xff1f; 一来翻阅更方便一点&#xff0c;二来还可以…

国际期货行情相关术语

1&#xff09;合约&#xff1a;期货行情表提供了期货交易的相关信息 &#xff0c;行情表中每一个期货合约都有合约代码&#xff08;由期货合约交易代码和合约到期月份组成&#xff09;来标识。 &#xff08;2&#xff09;开盘价&#xff1a;当日某一期货合约交易开始前五分钟集…

QT——事件

一、什么是事件 在QT中,事件(Event)是指由特定对象发生的动作或状态变化,通常用于响应用户的操作。事件可以是鼠标点击、键盘输入、窗口移动等用户操作,也可以是系统发出的信号,比如定时器超时、网络数据到达等。在QT中,可以通过连接信号与槽(Signals and Slots)的方…

服务器被墙是什么原因,怎么解决服务器被墙

服务器被墙通常是由于以下几个原因&#xff1a; 网络监管&#xff1a;某些国家或地区会对网络进行严格的监管&#xff0c;包括对特定网站、应用程序或服务进行屏蔽或封锁。这种情况下&#xff0c;服务器可能会被封锁&#xff0c;导致无法访问。 安全问题&#xff1a;服务器被发…

CDAM|数据资产管理:解锁企业价值的金钥匙

随着信息技术的飞速发展&#xff0c;数据已经成为企业最重要的资产之一。有效地管理数据资产&#xff0c;不仅有助于提升企业的运营效率&#xff0c;更能为企业的战略决策提供有力支持。本文将深入探讨数据资产管理的重要性、挑战以及实施策略&#xff0c;以期为企业打造一套高…

Linux - 进程

一、什么是进程 首先&#xff0c;Linux是一个多用户多进程的操作系统&#xff0c;系统上可以同时运行多个进程。 进程的产生&#xff1a;①是在执行程序或者命令时产生的&#xff1b;②定时任务进程 进程的类型&#xff1a;前台进程/后台进程 前台进程&#xff1a;一个终端…

AI 音乐大模型:创新的曙光还是创意产业的阴影?

&#x1f49d;&#x1f49d;&#x1f49d;欢迎来到我的博客&#xff0c;很高兴能够在这里和您见面&#xff01;希望您在这里可以感受到一份轻松愉快的氛围&#xff0c;不仅可以获得有趣的内容和知识&#xff0c;也可以畅所欲言、分享您的想法和见解。 推荐:kwan 的首页,持续学…

分享一下,如何搭建个人网站的步骤

在这段充满探索与创造的奇妙旅途中&#xff0c;我就像一位耐心的建筑师&#xff0c;在数字世界的荒原上精心雕琢&#xff0c;两周的时光缓缓流淌。每天&#xff0c;我与代码共舞&#xff0c;手执HTML、CSS与JavaScript这三大构建魔杖&#xff0c;一砖一瓦地筑起了梦想中的网络城…