自然语言处理:第六十四章 Qwen2代码解析

server/2024/11/24 7:01:55/

本人项目地址大全:Victor94-king/NLP__ManVictor: CSDN of ManVictor

原文地址:微信公众平台

项目地址: QwenLM/Qwen2.5: Qwen2.5 is the large language model series developed by Qwen team, Alibaba Cloud.

官网地址: 你好,Qwen2 | Qwen & Qwen2.5: 基础模型大派对! | Qwen


写在前面: 笔者更新不易,希望走过路过点个关注和赞,笔芯!!!

写在前面: 笔者更新不易,希望走过路过点个关注和赞,笔芯!!!

写在前面: 笔者更新不易,希望走过路过点个关注和赞,笔芯!!!




下面的源码内容来自transformers代码库中:transformers-4.45.2/src/transformers/models/qwen2/modeling_qwen2.py

实验准备

首先我们下载一些Qwen2需要的配置数据。下载地址:https://hf-mirror.com/Qwen/Qwen2-0.5B/tree/main

# 下载配置相关的文件
wget https://hf-mirror.com/Qwen/Qwen2-0.5B/resolve/main/config.json
wget https://hf-mirror.com/Qwen/Qwen2-0.5B/resolve/main/generation_config.json
wget https://hf-mirror.com/Qwen/Qwen2-0.5B/resolve/main/merges.txt
wget https://hf-mirror.com/Qwen/Qwen2-0.5B/resolve/main/tokenizer.json
wget https://hf-mirror.com/Qwen/Qwen2-0.5B/resolve/main/tokenizer_config.json
wget https://hf-mirror.com/Qwen/Qwen2-0.5B/resolve/main/vocab.json

注:权重文件我们可以不下载,我们这里仅仅是做一些流程实验,所以权重可以使用随机初始化。

下载transformers源码,我们这里使用的是 4.45.2版本,理论上之后的版本也都支持。

config.json文件修改

原始文件内容:

{"architectures": ["Qwen2ForCausalLM"],"attention_dropout": 0.0,"bos_token_id": 151643,"eos_token_id": 151643,"hidden_act": "silu","hidden_size": 896,"initializer_range": 0.02,"intermediate_size": 4864,"max_position_embeddings": 131072,"max_window_layers": 24,"model_type": "qwen2","num_attention_heads": 14,"num_hidden_layers": 24,"num_key_value_heads": 2,"rms_norm_eps": 1e-06,"rope_theta": 1000000.0,"sliding_window": 131072,"tie_word_embeddings": true,"torch_dtype": "bfloat16","transformers_version": "4.40.1","use_cache": true,"use_sliding_window": false,"vocab_size": 151936
}

我们这里修改 num_hidden_layers值为 4use_cache设置为 false,因为我们仅仅是实验一下,并不需要那么多层。其它内容保持不变。

文件结构

在transformers目录的examples目录里面新建一个Qwen2_learn目录,在Qwen2_learn下再建一个config文件夹,然后将上面下载的配置文件复制到config目录下。最终或得如下目录结构:

├── __init__.py
├── config
│   ├── config.json
│   ├── generation_config.json
│   ├── merges.txt
│   ├── tokenizer.json
│   ├── tokenizer_config.json
│   └── vocab.json
└── main.py

主要代码

下面是主体代码:

from src.transformers.models.qwen2.configuration_qwen2 import Qwen2Config
from src.transformers.models.qwen2.tokenization_qwen2 import Qwen2Tokenizer
from src.transformers.models.qwen2.modeling_qwen2 import Qwen2ForCausalLMconfig = Qwen2Config.from_pretrained("./config")
tokenizer = Qwen2Tokenizer.from_pretrained("./config")
model = Qwen2ForCausalLM(config=config)
print("模型细节: ")
print(model)
print("*="*80)
print("文本编码:")
inputs = tokenizer(["你好啊", "简单的机器学习是为了让机器学习变得更简单而存在的"],add_special_tokens=True,max_length=10,padding=True,truncation=True,return_tensors="pt")
print(inputs)
print("*="*80)
print("模型输出:")
print(model(**inputs))

不出意外的话,你可以看到下面的输出内容:

模型细节: 
Qwen2ForCausalLM((model): Qwen2Model((embed_tokens): Embedding(151936, 896)(layers): ModuleList((0-3): 4 x Qwen2DecoderLayer((self_attn): Qwen2SdpaAttention((q_proj): Linear(in_features=896, out_features=896, bias=True)(k_proj): Linear(in_features=896, out_features=128, bias=True)(v_proj): Linear(in_features=896, out_features=128, bias=True)(o_proj): Linear(in_features=896, out_features=896, bias=False)(rotary_emb): Qwen2RotaryEmbedding())(mlp): Qwen2MLP((gate_proj): Linear(in_features=896, out_features=4864, bias=False)(up_proj): Linear(in_features=896, out_features=4864, bias=False)(down_proj): Linear(in_features=4864, out_features=896, bias=False)(act_fn): SiLU())(input_layernorm): Qwen2RMSNorm((896,), eps=1e-06)(post_attention_layernorm): Qwen2RMSNorm((896,), eps=1e-06)))(norm): Qwen2RMSNorm((896,), eps=1e-06)(rotary_emb): Qwen2RotaryEmbedding())(lm_head): Linear(in_features=896, out_features=151936, bias=False)
)
*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=
文本编码:
{'input_ids': tensor([[108386, 103924, 151643, 151643, 151643, 151643, 151643, 151643, 151643,151643],[105172, 102182, 100134, 104802,  99258, 102182, 100134, 112606, 100405,68536]]), 'attention_mask': tensor([[1, 1, 0, 0, 0, 0, 0, 0, 0, 0],[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}
*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=
模型输出:
Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)
CausalLMOutputWithPast(loss=None, logits=tensor([[[ 1.5059,  0.6765,  0.2425,  ...,  0.4329, -0.0789, -1.0450],[ 0.3321,  0.8809,  0.6826,  ...,  0.0330,  0.0865, -0.6893],[ 0.2471,  0.7115,  0.5307,  ..., -0.0703,  0.1209, -0.7370],...,[ 0.3910,  0.7432,  0.3905,  ...,  0.0459,  0.2107, -0.6613],[ 0.3790,  0.7864,  0.4028,  ...,  0.0793,  0.2166, -0.6966],[ 0.3704,  0.8088,  0.4358,  ...,  0.0567,  0.2196, -0.7045]],[[ 1.4859, -0.7797,  0.9490,  ..., -0.0205, -0.2090, -0.7289],[ 1.5965, -0.2371,  0.7803,  ..., -0.8275, -0.1699, -0.0016],[ 1.2100, -0.2230,  0.8658,  ..., -0.0166, -0.0579, -0.5130],...,[ 0.5131, -0.2756,  0.8019,  ..., -0.0026,  0.3006, -1.2691],[ 0.2210, -0.0853,  0.9619,  ..., -0.1808,  0.5546, -1.0678],[ 0.4743,  0.1699,  0.6723,  ..., -0.0445,  0.4406, -0.9143]]],grad_fn=<UnsafeViewBackward0>), past_key_values=None, hidden_states=None, attentions=None)

有了上面的内容,我们基本流程就是搭好了,下面就可以使用我们自己喜欢的IDEA进行各种内容的调试了。我这里使用的是 pycharm

Qwen2Model

Qwen2ForCausalLM主体主要是 Qwen2Model,所以我们主要来看一下 Qwen2Model中的输入输出部分。

输入

对于 Qwen2Model的输入主要是以下参数

input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
  • input_ids的shape是 [bs, seq_len],即batch_size和序列的长度组成的二维矩阵。里面的元素值是token在词汇表中对应的索引信息。

  • attention_mask的shape和 input_ids shape是一直的,也是 [bs, seq_len],元素取值要么是1,要么是0,1表示 input_ids对应位置的元素是有效的,0则表示无效的,在后续计算attention时,只有为1的位置才会被真正的计算。

  • position_ids的shape也是 [bs, seq_len],表达元素的位置的信息。

  • past_key_values:预先计算的隐藏状态(自注意力块和交叉注意力块中的键和值),可以用来加速序列解码。这通常包括模型在解码的前一阶段返回的 past_key_values,当 use_cache=Trueconfig.use_cache=True时。

    允许两种格式:

    模型将输出与输入相同的缓存格式。如果没有传递 past_key_values,将返回传统的缓存格式。

    如果使用了 past_key_values,用户可以选择性地只输入最后一个 input_ids(那些没有给这个模型提供过去键值状态的),形状为 (batch_size, 1),而不是所有 input_ids的形状 (batch_size, sequence_length)

    注:这个参数一般情况在推理的时候使用,训练的时候不用。

    • 一个 ~cache_utils.Cache实例,参见我kv缓存指南;
    • 一个长度为 config.n_layers的元组,其中每个元组包含两个形状为 (batch_size, num_heads, sequence_length, embed_size_per_head)torch.FloatTensor张量。这也被称为传统的缓存格式。
  • inputs_embeds:形状为 (batch_size, sequence_length, hidden_size), 可选地,您可以选择不传递 input_ids,而是直接传递嵌入表示。这在您想要对如何将 input_ids 索引转换为相关向量有更多的控制权时很有用,而不是使用模型内部的嵌入查找矩阵。

  • use_cache:如果设置为 True,则返回 past_key_values 键值状态,可以用来加速解码(参见 past_key_values)。

  • output_attentions:是否返回所有注意力层的注意力张量。有关返回张量的更多详细信息,请参见返回张量中的 attentions

  • output_hidden_states:是否返回所有层的隐藏状态。有关返回张量的更多详细信息,请参见返回张量中的 hidden_states

  • return_dict:是否返回一个 ~utils.ModelOutput而不是一个普通的元组。

  • cache_position:描述输入序列标记位置的索引。与 position_ids 不同,这个张量不受填充(padding)的影响。它用于在正确的位置更新缓存,并推断完整的序列长度。

上面就是在 forward中所需要的所有参数。下面我们将结合代码的内容实现,以及参数的具体值来简单实验一下。通过实验过程来逐步理解代码逻辑。

["你好啊", "简单的机器学习是为了让机器学习变得更简单而存在的"]

这个样例产生的tokens结果为:

{'input_ids': tensor([[108386, 103924, 151643, 151643, 151643, 151643, 151643, 151643, 151643,151643],[105172, 102182, 100134, 104802,  99258, 102182, 100134, 112606, 100405,68536]]), 'attention_mask': tensor([[1, 1, 0, 0, 0, 0, 0, 0, 0, 0],[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}

即得到的shape为:[2, 10]

由上一节 print(model)内容:

Qwen2ForCausalLM((model): Qwen2Model((embed_tokens): Embedding(151936, 896)(layers): ModuleList((0-3): 4 x Qwen2DecoderLayer((self_attn): Qwen2SdpaAttention((q_proj): Linear(in_features=896, out_features=896, bias=True)(k_proj): Linear(in_features=896, out_features=128, bias=True)(v_proj): Linear(in_features=896, out_features=128, bias=True)(o_proj): Linear(in_features=896, out_features=896, bias=False)(rotary_emb): Qwen2RotaryEmbedding())(mlp): Qwen2MLP((gate_proj): Linear(in_features=896, out_features=4864, bias=False)(up_proj): Linear(in_features=896, out_features=4864, bias=False)(down_proj): Linear(in_features=4864, out_features=896, bias=False)(act_fn): SiLU())(input_layernorm): Qwen2RMSNorm((896,), eps=1e-06)(post_attention_layernorm): Qwen2RMSNorm((896,), eps=1e-06)))(norm): Qwen2RMSNorm((896,), eps=1e-06)(rotary_emb): Qwen2RotaryEmbedding())(lm_head): Linear(in_features=896, out_features=151936, bias=False)
)

我们看到 Qwen2Model主体是由 embed_tokens + 4*(self_attn + mlp + input_layernorm + post_attention_layernorm) + norm + rotary_emb组成的。

详情

embed_tokens层:

embed_tokens就是我们熟悉的 nn.Embedding初始化得到的层。即:

self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)

其中:config.vocab_size=151936, config.hidden_size=896, self.padding_idx=config.pad_token_id=None。当 self.padding_idxNone时候,默认取值就为0。

对于shape为 [2, 10]的输入,经过 embed_tokens层,可获得shape为 [2, 10, 896],记为 inputs_embeds

cache_position和position_ids:

因为 cache_position和position_ids都是 None(注:正对本样例而言),所以cache_position直接是通过传进来的序列长度计算得到的,即为:tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])position_ids为:tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])

causal_mask:

causal_mask是由方法 self._update_causal_mask产生的,它将产生四维的矩阵数据,shape为 [2, 1, 10, 10],即 [bs, 1, seq_len, seq_len],我们这里展示一下最后两维的数据,分别是 causal_mask[0,0][:5, :5]causal_mask[1,0][:5, :5],如下:

# causal_mask[0,0][:5, :5]
tensor([[ 0.0000e+00, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],[ 0.0000e+00,  0.0000e+00, -3.4028e+38, -3.4028e+38, -3.4028e+38],[ 0.0000e+00,  0.0000e+00, -3.4028e+38, -3.4028e+38, -3.4028e+38],[ 0.0000e+00,  0.0000e+00, -3.4028e+38, -3.4028e+38, -3.4028e+38],[ 0.0000e+00,  0.0000e+00, -3.4028e+38, -3.4028e+38, -3.4028e+38]])
# causal_mask[1,0][:5, :5]
tensor([[ 0.0000e+00, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],[ 0.0000e+00,  0.0000e+00, -3.4028e+38, -3.4028e+38, -3.4028e+38],[ 0.0000e+00,  0.0000e+00,  0.0000e+00, -3.4028e+38, -3.4028e+38],[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00, -3.4028e+38],[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00]])

可以看出下三角全为0,causal_mask[0,0][:5, :5]不全为0的原因是由attention_mask引起的,即pad部分是不用去计算的。

rotary_emb:

rotary_emb层只计算一次,然后运用到后面的各层,这一层是没有参数的,不参与训练。使用旋转位置编码最直接的好处有:

  • 可以使用绝对位置编码来表示相对位置编码;
  • 计算量是线性的;
  • 通过配置,可以实现一定的长度往外延拓能力;

rotary_emb主要用于计算cos和sin的值,即计算公式:

class Qwen2RotaryEmbedding(nn.Module):def __init__(self,dim=None,max_position_embeddings=2048,base=10000,device=None,scaling_factor=1.0,rope_type="default",config: Optional[Qwen2Config] = None,):super().__init__()# TODO (joao): remove the `if` below, only used for BCself.rope_kwargs = {}if config is None:logger.warning_once("`Qwen2RotaryEmbedding` can now be fully parameterized by passing the model config through the ""`config` argument. All other arguments will be removed in v4.46")self.rope_kwargs = {"rope_type": rope_type,"factor": scaling_factor,"dim": dim,"base": base,"max_position_embeddings": max_position_embeddings,}self.rope_type = rope_typeself.max_seq_len_cached = max_position_embeddingsself.original_max_seq_len = max_position_embeddingselse:# BC: "rope_type" was originally "type"if config.rope_scaling is not None:self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))else:self.rope_type = "default"self.max_seq_len_cached = config.max_position_embeddingsself.original_max_seq_len = config.max_position_embeddingsself.config = configself.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]# 这里会获取到一个shape为[config.hidden_size // config.num_attention_heads//2]的inv_freq# 因为是多头,所以实际上每个头的维度是config.hidden_size // config.num_attention_heads# 再除以2是由公式确定的,具体看下面的公式矩阵.inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)self.register_buffer("inv_freq", inv_freq, persistent=False)self.original_inv_freq = self.inv_freqdef _dynamic_frequency_update(self, position_ids, device):"""dynamic RoPE layers should recompute `inv_freq` in the following situations:1 - growing beyond the cached sequence length (allow scaling)2 - the current sequence length is in the original scale (avoid losing precision with small sequences)"""seq_len = torch.max(position_ids) + 1if seq_len > self.max_seq_len_cached:  # growthinv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len, **self.rope_kwargs)self.register_buffer("inv_freq", inv_freq, persistent=False)  # TODO joao: may break with compilationself.max_seq_len_cached = seq_lenif seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len:  # resetself.register_buffer("inv_freq", self.original_inv_freq, persistent=False)self.max_seq_len_cached = self.original_max_seq_len@torch.no_grad()def forward(self, x, position_ids):"""x: shape为[bs, seq_len, hidden_size]position_ids: shape为[1, seq_len]"""if "dynamic" in self.rope_type:self._dynamic_frequency_update(position_ids, device=x.device)# Core RoPE block# self.inv_freq本身的shape为[32], 经过下面的操作可获得[1, 32, 1]inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)position_ids_expanded = position_ids[:, None, :].float()  # shape为[1, 1, seq_len]# Force float32 (see https://github.com/huggingface/transformers/pull/29285)device_type = x.device.typedevice_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"with torch.autocast(device_type=device_type, enabled=False):# 经过[1, 32, 1]和[1, 1, seq_len]矩阵乘法之后可以得到[1, 32, seq_len]# 再经过变换可以得到[1, seq_len, 32]freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)emb = torch.cat((freqs, freqs), dim=-1)  # [1, seq_len, 64]cos = emb.cos()sin = emb.sin()# Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attentioncos = cos * self.attention_scalingsin = sin * self.attention_scalingreturn cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)

原位置编码公式:

代码里面得到的是:

注:更多文档请参考Transformer升级之路:2、博采众长的旋转式位置编码[1]

经过 rotary_emb可以得到 position_embeddings,它是一个元组,分别是 (cos, sin)对应的值,它们的 shape都是 [1, seq_len, 64]

**self_attn: **

这里使用 Qwen2SdpaAttention来计算 self_attention,下面我们仔细介绍一下这个模块。

首先是 Qwen2SdpaAttention继承自 Qwen2Attention,然后修改了其forward方法。而 Qwen2Attention初始化方案主要初始化了4个可训练参数权重,分别是 self.q_proj、self.k_proj、self.v_proj、self.o_proj,如下代码:

self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
  • self.hidden_size=config.hidden_size=896
  • self.num_heads=config.num_attention_heads=14
  • self.head_dim=self.hidden_size // self.num_heads=64
  • self.num_key_value_heads=config.num_key_value_heads=2
  • 注意这里的 q, k, v偏置全部设为了 True,即 bias=True

接着我们看一下 Qwen2Attention中的 forward部分:

def forward(self,hidden_states: torch.Tensor,attention_mask: Optional[torch.Tensor] = None,position_ids: Optional[torch.LongTensor] = None,past_key_value: Optional[Cache] = None,output_attentions: bool = False,use_cache: bool = False,cache_position: Optional[torch.LongTensor] = None,position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:"""参数shape说明:hidden_states: [bs, seq_len, hidden_size]attention_mask: [bs, 1, seq_len, seq_len]position_ids: [1, seq_len]cache_position: [seq_len]position_embeddings: 元组数据,即(cos, sin),shape都是[1, seq_len, self.head_dim]"""bsz, q_len, _ = hidden_states.size()# [bs, seq_len, hidden_size]=[bs, seq_len, 896]query_states = self.q_proj(hidden_states)# [bs, seq_len, self.num_key_value_heads * self.head_dim]=[bs, seq_len, 128] key_states = self.k_proj(hidden_states)# [bs, seq_len, self.num_key_value_heads * self.head_dim]=[bs, seq_len, 128]value_states = self.v_proj(hidden_states)# [bs, self.num_heads, seq_len, self.head_dim]=[bs, 14, sql_len, 64]query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)# [bs, self.num_key_value_heads, seq_len, self.head_dim] = [bs, 2, sql_len, 64]key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)# [bs, self.num_key_value_heads, seq_len, self.head_dim] = [bs, 2, sql_len, 64]value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)if position_embeddings is None:logger.warning_once("The attention layers in this model are transitioning from computing the RoPE embeddings internally ""through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed ""`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be ""removed and `position_embeddings` will be mandatory.")cos, sin = self.rotary_emb(value_states, position_ids)else:# cos: [1, seq_len, self.head_dim]=[1, seq_len, 64]# sin: [1, seq_len, self.head_dim]=[1, seq_len, 64]cos, sin = position_embeddings# 针对query_states和key_states运用旋转位置编码,即使用下面的公式。# 得到的shape为[bs, self.num_heads, seq_len, self.head_dim]=[bs, 14, seq_len, 64]# 和 [bs, self.num_key_value_heads, seq_len, self.head_dim] = [bs, 2, seq_len, 64]query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)if past_key_value is not None:cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE modelskey_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)# repeat k/v heads if n_kv_heads < n_heads# 这里的self.num_key_value_groups=self.num_heads // self.num_key_value_heads=7# num_key_value_groups作用请看下面注释。# 得到的shape为[bs, self.num_heads, seq_len, self.head_dim]=[bs, 14, seq_len, 64]key_states = repeat_kv(key_states, self.num_key_value_groups)value_states = repeat_kv(value_states, self.num_key_value_groups)# 计算attn_weights,其shape是[bs, self.num_head, seq_len, seq_len]attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)if attention_mask is not None:  # no matter the length, we just slice itcausal_mask = attention_mask[:, :, :, : key_states.shape[-2]]attn_weights = attn_weights + causal_mask# upcast attention to fp32attn_weights = nn.functional.softmax(attn_weights, dim=-1,dtype=torch.float32).to(query_states.dtype)attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)# [bs, self.num_head, seq_len, seq_len]矩阵乘以[bs, self.num_head, seq_len, head_dim]# 得到[bs, self.num_head, seq_len, head_dim]attn_output = torch.matmul(attn_weights, value_states)if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):raise ValueError(f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"f" {attn_output.size()}")# [bs, seq_len, self.num_head, head_dim]attn_output = attn_output.transpose(1, 2).contiguous()# [bs, seq_len, self.hidden_size]attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)# [bs, seq_len, self.hidden_size]attn_output = self.o_proj(attn_output)if not output_attentions:attn_weights = Nonereturn attn_output, attn_weights, past_key_value

注:

在Transformer模型中,num_key_value_groups 是分组查询注意力(Grouped-query attention, GQA)的一个概念。分组查询注意力是多头注意力的一种改进形式,它在保持一定数量的query头的同时,减少key和value头的数量,以此来提高计算效率。

具体来说,num_key_value_groups 表示的是将key和value头分组的数量。在标准的多头注意力中,每个query头都会与一个对应的key和value头配对。但在分组查询注意力中,多个query头会共享一组key和value头。这样做可以减少模型的参数数量和计算量,从而提高效率。

例如,如果我们有8个query头,但在分组查询注意力中,我们可能只有4个key-value组,那么 num_key_value_groups 就是4。这意味着每两个query头会共享一个key和value头。在实际计算中,这组key和value会被复制(或者说广播)到与query头相同的数量,以便进行注意力权重的计算。

这种方法在保持多头注意力的优势的同时,减少了参数数量和计算复杂度,有助于提升模型的推理速度,尤其是在解码阶段。但是,它也需要仔细设计,以避免对模型性能产生负面影响。

在实际的代码实现中,num_key_value_groups 通常是通过将总的query头数除以key-value头数来计算的。例如,如果 num_heads(query头的数量)是8,而 num_key_value_heads(key-value头的数量)是4,那么 num_key_value_groups 就是2,意味着每两个query头共享一个key-value头。

mlp:

对于这一层,其实直接看代码就可以理解了,没有特别难的内容在里面。所以这里就不进行介绍了。

总结

本篇文章主要集中在介绍数据在流转的过程中,各个矩阵的shape,通过shape的变化,来理解整个过程。其实如果对Bert本身有理解的情况下,整篇内容只需要理解旋转位置编码的实现以及分组查询注意力的理解就好了,其它内容和Bert相比,并没有本质的变化(除了attention_mask部分)。


http://www.ppmy.cn/server/144476.html

相关文章

[Linux] 进程地址空间

1.程序地址空间铺垫&#xff08;以32位为例) a.粗力度验证上面的规则: 这时候我们就可以看出地址由低到高 b.地址有增长方向的验证 栈区变量怎么办? 天然的*heap等这些就是天然的指针变量在栈上的, 堆 上开辟一块空间,由栈上变量指向他栈向下(地址空间减少的方向)增长堆向上增…

论文阅读:SIMBA: single-cell embedding along with features

Chen, H., Ryu, J., Vinyard, M.E. et al. SIMBA: single-cell embedding along with features. Nat Methods 21, 1003–1013 (2024). 论文地址&#xff1a;https://doi.org/10.1038/s41592-023-01899-8 代码地址&#xff1a;https://github.com/pinellolab/simba. 摘要 大多…

全新配置ubuntu18.04深度学习环境

1、下载显卡驱动 1.1、驱动下载 连接&#xff1a;显卡驱动 手动驱动搜索-》查找-》查看-》下载 下载可使用指令 wget https://us.download.nvidia.com/XFree86/Linux-x86_64/535.216.01/NVIDIA-Linux-x86_64-535.216.01.run 2、下载安装cuda12.0 wget https://developer.do…

深度学习实战人脸识别

文章目录 前言一、人脸识别一般过程二、人脸检测主流算法1. MTCNN2. RetinaFace3. CenterFace4. BlazeFace5. YOLO6. SSD7. CascadeCNN 三、人脸识别主流算法1.deepface2.FaceNet3.ArcFace4.VGGFace5.DeepID 四、人脸识别系统实现0.安装教程与资源说明1. 界面采用PyQt5框架2.人…

PHP实现选择排序

选择排序&#xff08;Selection Sort&#xff09;是一种简单直观的排序算法。它的工作原理是&#xff1a;首先在未排序序列中找到最小&#xff08;或最大&#xff09;元素&#xff0c;存放到排序序列的起始位置&#xff0c;然后&#xff0c;再从剩余未排序元素中继续寻找最小&a…

全面解析:单列集合Collection和双列集合Map

Java中的集合&#xff08;Collection&#xff09;是一个框架&#xff0c;用于存储、操作数据。集合框架包括了许多接口和类&#xff0c;用于表示数据的存储方式。集合主要分为两大类&#xff1a;Collection 和 Map。 单列集合Collection的继承体系图&#xff1a; Collection 接…

【什么是RabbitMQ】

RabbitMQ&#xff1a;可靠、灵活的消息中间件 在当今的分布式系统和微服务架构中&#xff0c;消息中间件扮演着至关重要的角色。RabbitMQ&#xff0c;作为一款开源的消息代理软件&#xff0c;以其可靠性、灵活性、可扩展性和多语言支持等特点&#xff0c;在众多消息队列系统中…

SQL99版外连接

外连接 看这样的场景&#xff0c;在ta和tb两表中查询没有对应年龄数据的学生姓名和年龄 SELECT tb.name,ta.age FROM tb INNER JOIN ta ON tb.ta_idta.id WHERE ta.id IS NULL; 结果没有,所以前面的查询是解决不了这种问题&#xff01;&#xff01;&#xff01; 所以外连接…