Transformer解码器终极指南:从Masked Attention到Cross-Attention的PyTorch逐行实现

ops/2025/2/12 23:49:05/

Transformer 解码器深度解读 + 代码实战


1. 解码器核心作用

Transformer 解码器的核心任务是基于编码器的语义表示逐步生成目标序列(如翻译结果、文本续写)。它通过 掩码自注意力编码器-解码器交叉注意力,实现自回归生成并融合源序列信息。与编码器的核心差异:

  • 掩码机制:防止解码时看到未来信息(训练时并行,推理时逐步生成)。
  • 交叉注意力:将编码器输出作为 Key/Value,解码器当前状态作为 Query。

2. 解码器单层结构详解

每层解码器包含以下模块(附 PyTorch 代码):


2.1 掩码多头自注意力(Masked Multi-Head Self-Attention)
class MaskedMultiHeadAttention(nn.Module):def __init__(self, embed_size, heads):super().__init__()self.embed_size = embed_sizeself.heads = headsself.head_dim = embed_size // heads# 生成 Q/K/V 的线性层self.to_qkv = nn.Linear(embed_size, embed_size * 3)self.scale = self.head_dim ** -0.5  # 缩放因子# 输出线性层self.to_out = nn.Linear(embed_size, embed_size)def forward(self, x, mask=None):batch_size, seq_len, _ = x.shape# 生成 Q/K/V 并分割多头qkv = self.to_qkv(x).chunk(3, dim=-1)q, k, v = map(lambda t: t.view(batch_size, seq_len, self.heads, self.head_dim), qkv)# 计算注意力分数 QK^T / sqrt(d_k)attn = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale# 应用下三角掩码(防止看到未来信息)if mask is not None:attn = attn.masked_fill(mask == 0, -1e10)  # 掩码位置填充极小值else:# 自动生成下三角掩码(训练时使用)causal_mask = torch.tril(torch.ones(seq_len, seq_len)).bool().to(x.device)attn = attn.masked_fill(~causal_mask, -1e10)# Softmax 归一化attn = torch.softmax(attn, dim=-1)# 加权求和out = torch.einsum('bhij,bhjd->bhid', attn, v)out = out.reshape(batch_size, seq_len, self.embed_size)return self.to_out(out)

代码解析

  • causal_mask 生成下三角矩阵(主对角线及以下为1,其余为0),确保解码时仅能看到当前位置及之前的信息。
  • 推理时可手动传递掩码,控制生成长度。

2.2 编码器-解码器交叉注意力(Cross-Attention)
class CrossAttention(nn.Module):def __init__(self, embed_size, heads):super().__init__()self.embed_size = embed_sizeself.heads = headsself.head_dim = embed_size // heads# 生成 Q 的线性层(解码器输入)self.to_q = nn.Linear(embed_size, embed_size)# 生成 K/V 的线性层(编码器输出)self.to_kv = nn.Linear(embed_size, embed_size * 2)self.scale = self.head_dim ** -0.5self.to_out = nn.Linear(embed_size, embed_size)def forward(self, x, encoder_output, mask=None):batch_size, seq_len, _ = x.shape# 生成 Q 来自解码器输入q = self.to_q(x).view(batch_size, seq_len, self.heads, self.head_dim)# 生成 K/V 来自编码器输出k, v = self.to_kv(encoder_output).chunk(2, dim=-1)k = k.view(batch_size, -1, self.heads, self.head_dim)  # 编码器序列长度可能不同v = v.view(batch_size, -1, self.heads, self.head_dim)# 计算注意力分数 QK^T / sqrt(d_k)attn = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale# 应用掩码(如填充掩码)if mask is not None:attn = attn.masked_fill(mask == 0, -1e10)attn = torch.softmax(attn, dim=-1)out = torch.einsum('bhij,bhjd->bhid', attn, v)out = out.reshape(batch_size, seq_len, self.embed_size)return self.to_out(out)

代码解析

  • Q 来自解码器输入K/V 来自编码器输出,实现跨序列信息融合。
  • 支持自定义掩码(如处理源序列的填充位置)。

2.3 解码器单层完整实现
class TransformerDecoderLayer(nn.Module):def __init__(self, embed_size, heads, dropout=0.1):super().__init__()self.masked_attn = MaskedMultiHeadAttention(embed_size, heads)self.cross_attn = CrossAttention(embed_size, heads)self.ffn = FeedForward(embed_size)  # 复用编码器的FFNself.norm1 = nn.LayerNorm(embed_size)self.norm2 = nn.LayerNorm(embed_size)self.norm3 = nn.LayerNorm(embed_size)self.dropout = nn.Dropout(dropout)def forward(self, x, encoder_output, src_mask=None, tgt_mask=None):# 1. 掩码自注意力masked_attn_out = self.masked_attn(x, tgt_mask)x = x + self.dropout(masked_attn_out)x = self.norm1(x)# 2. 交叉注意力(Q来自x,K/V来自encoder_output)cross_attn_out = self.cross_attn(x, encoder_output, src_mask)x = x + self.dropout(cross_attn_out)x = self.norm2(x)# 3. 前馈网络ffn_out = self.ffn(x)x = x + self.dropout(ffn_out)x = self.norm3(x)return x

3. 完整解码器实现
class TransformerDecoder(nn.Module):def __init__(self, vocab_size, embed_size, layers, heads, dropout=0.1):super().__init__()self.embedding = nn.Embedding(vocab_size, embed_size)self.pos_encoding = PositionalEncoding(embed_size)self.layers = nn.ModuleList([TransformerDecoderLayer(embed_size, heads, dropout)for _ in range(layers)])self.fc_out = nn.Linear(embed_size, vocab_size)  # 输出层预测词表概率def forward(self, x, encoder_output, src_mask=None, tgt_mask=None):x = self.embedding(x)x = self.pos_encoding(x)for layer in self.layers:x = layer(x, encoder_output, src_mask, tgt_mask)logits = self.fc_out(x)  # (batch_size, seq_len, vocab_size)return logits

4. 实战测试:文本翻译模拟
# 参数设置
vocab_size = 10000  # 目标语言词表大小
embed_size = 512
layers = 6
heads = 8# 初始化编码器和解码器
encoder = TransformerEncoder(vocab_size, embed_size, layers, heads)
decoder = TransformerDecoder(vocab_size, embed_size, layers, heads)# 模拟输入(源语言句子)
src = torch.randint(0, vocab_size, (32, 20))  # (batch_size=32, src_seq_len=20)
# 编码器输出
encoder_output = encoder(src)# 模拟目标输入(目标语言句子,训练时右移一位)
tgt = torch.randint(0, vocab_size, (32, 25))  # (batch_size=32, tgt_seq_len=25)
# 解码器输出
logits = decoder(tgt, encoder_output)
print("输出形状:", logits.shape)  # torch.Size([32, 25, 10000])

🎉 恭喜! 至此你已经掌握了Transformer解码器的核心原理与实现。无论是机器翻译、文本生成,还是对话系统,解码器都是生成任务的核心引擎。

下一步建议

  1. 尝试在真实数据集(如WMT英德翻译)上训练模型。
  2. 探索 束搜索(Beam Search)温度采样(Temperature Sampling) 等推理优化技术。
  3. 访问 Transformer官方代码库 或 Hugging Face库 深入学习工业级实现。

动手实践是掌握AI的最佳方式——赶紧修改代码参数,观察模型变化吧!如果遇到问题,欢迎在评论区留言讨论,我们一起解决! 🌟


希望这篇解析能助你彻底理解Transformer解码器,期待看到你的实战成果! 😊


http://www.ppmy.cn/ops/157897.html

相关文章

02.10 TCP之文件传输

1.思维导图 2.作业 服务器代码&#xff1a; #include <stdio.h> #include <string.h> #include <unistd.h> #include <stdlib.h> #include <sys/types.h> #include <sys/stat.h> #include <fcntl.h> #include <pthread.h> …

Node.js笔记入门篇

黑马程序员视频地址&#xff1a; Node.js与Webpack-01.Node.js入门 基本认识 概念 定义&#xff1a;Node.js 是一个免费、开源、跨平台的 JavaScript 运行时环境, 它让开发人员能够创建服务器 Web 应用、命令行工具和脚本 作用&#xff1a;使用Node.js 编写服务器端程序 ✓ …

输入框相关,一篇文章总结所有前端文本输入的应用场景和实现方法,(包含源码,建议收藏)

前言 本篇文章所有的代码&#xff0c;都是在 vue vite ts 项目基础之上实现的&#xff0c;这样也是为了方便大家直接用源码&#xff0c;在开始之前建议大家阅读这篇《零基础搭建 vite项 目教程》。此项目就是这个教程搭建的&#xff0c;本篇文章关于输入框的相关代码是此项目…

分享一款免费的AI大模型字幕工具,支持语音识别、字幕断句、优化、翻译、视频合成等全流程自动处理(支持抖音、B站、油管等国内外多平台视频下载与处理)

文章目录 📖 介绍 📖🏡 演示环境 🏡📒 AI字幕工具:全平台视频创作的福音 📒💡 功能与特点:一网打尽⚙️ 使用⚓️ 相关链接 ⚓️📖 介绍 📖 还在为视频加字幕抓狂?🤯 平台限制多,操作又繁琐?别再挠破头皮啦!今天给大家分享的这款AI神器,简直是视频创…

Vue 3 和 <script setup> 的组件,它使用 v-for 来渲染一个嵌套的菜单结构。

Vue 3 和 <script setup> 的组件&#xff0c;它使用 v-for 来渲染一个嵌套的菜单结构。 [{"id": 1,"title": "Navigator One","children": [{"id": 11, "title": "Item One"},{"id": …

将jar制作成docker镜像运行

将jar制作成docker镜像运行 手动编写 Dockerfile 方式 1. 准备工作 确保你已经安装了 Docker&#xff0c;并且 Docker 服务正在运行。 有一个可运行的 JAR 文件&#xff0c;假设文件名为 your-application.jar。 修改springboot配置文件让日志输出到指定目录下文件中 appli…

网络安全知识--网络、网络安全产品及密码产品概述

网络、网络安全产品及密码产品概述 网络、安全产品网络安全关注重点 网络结构 网络设备&#xff1a;交换机、路由器、负载均衡 安全设备&#xff1a; 通信网络安全类:通信安全、网络监测与控制 区域边界安全类&#xff1a;隔离类、入侵防范、边界访问 安全服务&#xff…

ML.NET库学习003:基于时间序列的共享单车需求预测项目解析

文章目录 ML.NET库学习003&#xff1a;基于时间序列的共享单车需求预测项目解析项目主要目的和原理目的原理 项目概述数据来源工具与框架 Program.cs主要功能和步骤1. 数据加载与预处理2. 特征工程3. 模型训练4. 模型评估5. 模型生成 ModelScoringTester.cs分析与解读方法一&am…