Transformer解码器终极指南:从Masked Attention到Cross-Attention的PyTorch逐行实现

server/2025/2/13 9:20:03/

Transformer 解码器深度解读 + 代码实战


1. 解码器核心作用

Transformer 解码器的核心任务是基于编码器的语义表示逐步生成目标序列(如翻译结果、文本续写)。它通过 掩码自注意力编码器-解码器交叉注意力,实现自回归生成并融合源序列信息。与编码器的核心差异:

  • 掩码机制:防止解码时看到未来信息(训练时并行,推理时逐步生成)。
  • 交叉注意力:将编码器输出作为 Key/Value,解码器当前状态作为 Query。

2. 解码器单层结构详解

每层解码器包含以下模块(附 PyTorch 代码):


2.1 掩码多头自注意力(Masked Multi-Head Self-Attention)
class MaskedMultiHeadAttention(nn.Module):def __init__(self, embed_size, heads):super().__init__()self.embed_size = embed_sizeself.heads = headsself.head_dim = embed_size // heads# 生成 Q/K/V 的线性层self.to_qkv = nn.Linear(embed_size, embed_size * 3)self.scale = self.head_dim ** -0.5  # 缩放因子# 输出线性层self.to_out = nn.Linear(embed_size, embed_size)def forward(self, x, mask=None):batch_size, seq_len, _ = x.shape# 生成 Q/K/V 并分割多头qkv = self.to_qkv(x).chunk(3, dim=-1)q, k, v = map(lambda t: t.view(batch_size, seq_len, self.heads, self.head_dim), qkv)# 计算注意力分数 QK^T / sqrt(d_k)attn = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale# 应用下三角掩码(防止看到未来信息)if mask is not None:attn = attn.masked_fill(mask == 0, -1e10)  # 掩码位置填充极小值else:# 自动生成下三角掩码(训练时使用)causal_mask = torch.tril(torch.ones(seq_len, seq_len)).bool().to(x.device)attn = attn.masked_fill(~causal_mask, -1e10)# Softmax 归一化attn = torch.softmax(attn, dim=-1)# 加权求和out = torch.einsum('bhij,bhjd->bhid', attn, v)out = out.reshape(batch_size, seq_len, self.embed_size)return self.to_out(out)

代码解析

  • causal_mask 生成下三角矩阵(主对角线及以下为1,其余为0),确保解码时仅能看到当前位置及之前的信息。
  • 推理时可手动传递掩码,控制生成长度。

2.2 编码器-解码器交叉注意力(Cross-Attention)
class CrossAttention(nn.Module):def __init__(self, embed_size, heads):super().__init__()self.embed_size = embed_sizeself.heads = headsself.head_dim = embed_size // heads# 生成 Q 的线性层(解码器输入)self.to_q = nn.Linear(embed_size, embed_size)# 生成 K/V 的线性层(编码器输出)self.to_kv = nn.Linear(embed_size, embed_size * 2)self.scale = self.head_dim ** -0.5self.to_out = nn.Linear(embed_size, embed_size)def forward(self, x, encoder_output, mask=None):batch_size, seq_len, _ = x.shape# 生成 Q 来自解码器输入q = self.to_q(x).view(batch_size, seq_len, self.heads, self.head_dim)# 生成 K/V 来自编码器输出k, v = self.to_kv(encoder_output).chunk(2, dim=-1)k = k.view(batch_size, -1, self.heads, self.head_dim)  # 编码器序列长度可能不同v = v.view(batch_size, -1, self.heads, self.head_dim)# 计算注意力分数 QK^T / sqrt(d_k)attn = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale# 应用掩码(如填充掩码)if mask is not None:attn = attn.masked_fill(mask == 0, -1e10)attn = torch.softmax(attn, dim=-1)out = torch.einsum('bhij,bhjd->bhid', attn, v)out = out.reshape(batch_size, seq_len, self.embed_size)return self.to_out(out)

代码解析

  • Q 来自解码器输入K/V 来自编码器输出,实现跨序列信息融合。
  • 支持自定义掩码(如处理源序列的填充位置)。

2.3 解码器单层完整实现
class TransformerDecoderLayer(nn.Module):def __init__(self, embed_size, heads, dropout=0.1):super().__init__()self.masked_attn = MaskedMultiHeadAttention(embed_size, heads)self.cross_attn = CrossAttention(embed_size, heads)self.ffn = FeedForward(embed_size)  # 复用编码器的FFNself.norm1 = nn.LayerNorm(embed_size)self.norm2 = nn.LayerNorm(embed_size)self.norm3 = nn.LayerNorm(embed_size)self.dropout = nn.Dropout(dropout)def forward(self, x, encoder_output, src_mask=None, tgt_mask=None):# 1. 掩码自注意力masked_attn_out = self.masked_attn(x, tgt_mask)x = x + self.dropout(masked_attn_out)x = self.norm1(x)# 2. 交叉注意力(Q来自x,K/V来自encoder_output)cross_attn_out = self.cross_attn(x, encoder_output, src_mask)x = x + self.dropout(cross_attn_out)x = self.norm2(x)# 3. 前馈网络ffn_out = self.ffn(x)x = x + self.dropout(ffn_out)x = self.norm3(x)return x

3. 完整解码器实现
class TransformerDecoder(nn.Module):def __init__(self, vocab_size, embed_size, layers, heads, dropout=0.1):super().__init__()self.embedding = nn.Embedding(vocab_size, embed_size)self.pos_encoding = PositionalEncoding(embed_size)self.layers = nn.ModuleList([TransformerDecoderLayer(embed_size, heads, dropout)for _ in range(layers)])self.fc_out = nn.Linear(embed_size, vocab_size)  # 输出层预测词表概率def forward(self, x, encoder_output, src_mask=None, tgt_mask=None):x = self.embedding(x)x = self.pos_encoding(x)for layer in self.layers:x = layer(x, encoder_output, src_mask, tgt_mask)logits = self.fc_out(x)  # (batch_size, seq_len, vocab_size)return logits

4. 实战测试:文本翻译模拟
# 参数设置
vocab_size = 10000  # 目标语言词表大小
embed_size = 512
layers = 6
heads = 8# 初始化编码器和解码器
encoder = TransformerEncoder(vocab_size, embed_size, layers, heads)
decoder = TransformerDecoder(vocab_size, embed_size, layers, heads)# 模拟输入(源语言句子)
src = torch.randint(0, vocab_size, (32, 20))  # (batch_size=32, src_seq_len=20)
# 编码器输出
encoder_output = encoder(src)# 模拟目标输入(目标语言句子,训练时右移一位)
tgt = torch.randint(0, vocab_size, (32, 25))  # (batch_size=32, tgt_seq_len=25)
# 解码器输出
logits = decoder(tgt, encoder_output)
print("输出形状:", logits.shape)  # torch.Size([32, 25, 10000])

🎉 恭喜! 至此你已经掌握了Transformer解码器的核心原理与实现。无论是机器翻译、文本生成,还是对话系统,解码器都是生成任务的核心引擎。

下一步建议

  1. 尝试在真实数据集(如WMT英德翻译)上训练模型。
  2. 探索 束搜索(Beam Search)温度采样(Temperature Sampling) 等推理优化技术。
  3. 访问 Transformer官方代码库 或 Hugging Face库 深入学习工业级实现。

动手实践是掌握AI的最佳方式——赶紧修改代码参数,观察模型变化吧!如果遇到问题,欢迎在评论区留言讨论,我们一起解决! 🌟


希望这篇解析能助你彻底理解Transformer解码器,期待看到你的实战成果! 😊


http://www.ppmy.cn/server/167289.html

相关文章

PlantUml常用语法

PlantUml常用语法,将从类图、流程图和序列图这三种最常用的图表类型开始。 类图 基础语法 在 PlantUML 中创建类图时,你可以定义类(Class)、接口(Interface)以及它们之间的关系,如继承&#…

MapReduce简单应用(三)——高级WordCount

目录 1. 高级WordCount1.1 IntWritable降序排列1.2 输入输出格式1.3 处理流程 2. 代码和结果2.1 pom.xml中依赖配置2.2 工具类util2.3 高级WordCount2.4 结果 参考 本文引用的Apache Hadoop源代码基于Apache许可证 2.0,详情请参阅 Apache许可证2.0。 1. 高级WordCo…

Haskell语言的云计算

Haskell语言与云计算:结合高阶函数与分布式系统的力量 引言 云计算作为现代计算技术的重要组成部分,已经渗透到我们生活的方方面面。随着技术的不断进步,许多编程语言也开始了它们在云计算领域的探索与实践。Haskell作为一种具有强大类型系…

CP AUTOSAR标准之GPTDriver(AUTOSAR_SWS_GPTDriver)(更新中……)

1 简介和功能概述 该规范指定了AUTOSAR基础软件模块GPT驱动程序的功能、API和配置。   GPT驱动程序是微控制器抽象层(MCAL)的一部分。它初始化并控制微控制器的内部通用定时器(GPT)。   GPT驱动程序提供服务和配置参数 启动和停止硬件计时器获取计时器值控制时间触发的中断…

【Modelsim】medelsim查看仿真覆盖率的方法

最近做项目的时候需要对代码进行仿真覆盖率的分析,那么如何添加仿真覆盖率呢?配置方法如下。 如上图所示进行配置,在调用modelsim的时候就可以显示仿真覆盖率,如下图所示就是modesim的仿真覆盖率。

嵌入式AI革命:DeepSeek开源如何终结GPU霸权,开启单片机智能新时代?

2025年,全球AI领域最震撼的突破并非来自算力堆叠的超级模型,而是中国团队DeepSeek通过开源策略,推动大模型向微型化、低功耗场景的跨越。相对于当人们还在讨论千亿参数模型的训练成本被压缩到600万美金而言,被称作“核弹级别”的操…

Hadoop智能房屋推荐系统 爬虫1w+ 协同过滤余弦函数推荐 代码+视频教程+文档

Hadoop智能房屋推荐系统 爬虫1w 协同过滤余弦函数推荐 带视频教程 毕设设计 课题设计 【Hadoop项目】 1. data.csv上传到hadoop集群环境 2. data.csv数据清洗 3.MapReducer数据汇总处理, 将Reducer的结果数据保存到本地Mysql数据库中 4. SpringbootEchartsMySQL 显示数据分析结…

蓝桥杯试题:归并排序

一、问题描述 在一个神秘的岛屿上,有一支探险队发现了一批宝藏,这批宝藏是以整数数组的形式存在的。每个宝藏上都标有一个数字,代表了其珍贵程度。然而,由于某种神奇的力量,这批宝藏的顺序被打乱了,探险队…