文章目录
- 重点讲解
- 代码实现
- 流程图(维度变换示意图)
- self-attention示例
- 加入mask示例
MultiHeadAttention 是 Transformer 模型中的一个核心组件,它允许模型在处理序列的每个位置时同时考虑来自多个“视角”(即头部)的信息。这样做可以提高模型对不同位置关系的理解能力。
重点讲解
主要步骤:
- 线性变换得到QKV,并将QKV分割为多头
- 计算缩放点积注意力(注意mask可选)
- 拼接多头
- 最后再进行一次线性变换
代码实现
下面,我将使用 PyTorch 框架实现一个基本的 MultiHeadAttention
模块。
python">import torch
import torch.nn as nn
import torch.nn.functional as Fimport mathclass MultiHeadAttention(nn.Module):def __init__(self, d_model, num_heads):super(MultiHeadAttention, self).__init__()assert d_model % num_heads == 0, "d_model must be divisible by num_heads"self.d_model = d_modelself.num_heads = num_headsself.depth = d_model // num_heads# 定义线性层和输出线性层self.query_linear = nn.Linear(d_model, d_model)self.key_linear = nn.Linear(d_model, d_model)self.value_linear = nn.Linear(d_model, d_model)self.final_linear = nn.Linear(d_model, d_model)def split_heads(self, x, batch_size):"""分割最后一个维度到 (num_heads, depth).转置结果使得形状为 (batch_size, num_heads, seq_length, depth)"""x = x.view(batch_size, -1, self.num_heads, self.depth)return x.permute(0, 2, 1, 3)def forward(self, query, key, value, mask=None):batch_size = query.size(0)# 1. 线性层和分割到多头query = self.split_heads(self.query_linear(query), batch_size)key = self.split_heads(self.key_linear(key), batch_size)value = self.split_heads(self.value_linear(value), batch_size)# 2. 缩放点积注意力scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.depth)if mask is not None:scores = scores.masked_fill(mask == True, float('-inf'))attention_weights = F.softmax(scores, dim=-1)# 3. 将注意力权重应用到值上output = torch.matmul(attention_weights, value)# 4. 连接头部output = output.permute(0, 2, 1, 3).contiguous()output = output.view(batch_size, -1, self.d_model)# 5. 最后一次线性变换output = self.final_linear(output)return output
流程图(维度变换示意图)
self-attention示例
python"> d_model = 512 # 模型维度num_heads = 8 # 头数mha = MultiHeadAttention(d_model, num_heads)# 创建随机数据batch_size = 4seq_length = 60x = torch.rand(batch_size, seq_length, d_model) # 输入假设维度为 (batch_size, seq_length, d_model)output = mha(x, x, x) # 自注意力机制,qkv的输入相同;而cross-attention中,query来自decoder,kv来自encoderprint(output.shape)
加入mask示例
解码器的自注意力层需要确保当前位置只能注意到前面的位置(包括当前位置),而不是未来的位置。这通常通过一个未来位置掩码实现,它是一个下三角矩阵。
python">import torchdef generate_square_subsequent_mask(seq_len):"""生成一个未来步骤掩码,用于解码器中防止看到未来信息。"""mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()return maskd_model = 512 # 模型维度
num_heads = 8 # 头数
mha = MultiHeadAttention(d_model, num_heads)# 创建随机数据
batch_size = 4
seq_length = 60
x = torch.rand(batch_size, seq_length, d_model) # 输入假设维度为 (batch_size, seq_length, d_model)# 生成掩码并将其应用于解码器的自注意力层
future_mask = generate_square_subsequent_mask(seq_length).to(x.device)
output = mha(x, x, x, mask=future_mask) # 自注意力机制
print(output.shape) # 应为 (batch_size, seq_length, d_model)
注意广播机制
在 PyTorch 中,masked_fill
函数可以很灵活地处理维度差异情况,通过广播(broadcasting)机制来匹配维度。