Transformer 代码剖析6 - 位置编码 (pytorch实现)

ops/2025/3/1 20:23:17/

一、位置编码的数学原理与设计思想

1.1 核心公式解析

位置编码采用正弦余弦交替编码方案:
P E ( p o s , 2 i ) = sin ⁡ ( p o s 1000 0 2 i / d m o d e l ) P E ( p o s , 2 i + 1 ) = cos ⁡ ( p o s 1000 0 2 i / d m o d e l ) PE_{(pos,2i)} = \sin\left(\frac{pos}{10000^{2i/d_{model}}}\right) \\ PE_{(pos,2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{model}}}\right) PE(pos,2i)=sin(100002i/dmodelpos)PE(pos,2i+1)=cos(100002i/dmodelpos)

式中:

  • p o s pos pos:当前词在序列中的绝对位置
  • i i i:特征维度的索引( 0 ≤ i < d m o d e l / 2 0 \leq i < d_{model}/2 0i<dmodel/2
  • 1000 0 2 i / d m o d e l 10000^{2i/d_{model}} 100002i/dmodel:频率控制项,形成指数衰减的频率分布

1.2 设计优势分析

1. 绝对位置感知: 每个位置生成唯一编码模式
2. 相对位置建模: 通过三角函数加法公式可推导任意两个位置的关联度
3. 多频特征捕捉: 不同频率的正余弦波组合形成丰富的表征空间
4. 值域归一化: 所有编码值分布在[-1,1]区间,与词嵌入维度保持数值一致性
(图示:不同维度上的位置编码波形,高频维度对应快速变化,低频维度对应缓慢变化)
(图示:不同维度上的位置编码波形,高频维度对应快速变化,低频维度对应缓慢变化)

二、代码架构与执行流程

2.1 类结构设计

PositionalEncoding
__init__构造函数
创建零矩阵
配置梯度策略
构建位置索引
生成维度索引
计算正弦编码
计算余弦编码
forward前向传播
获取输入尺寸
返回截断编码

2.2 核心代码模块

python">class PositionalEncoding(nn.Module):def __init__(self, d_model, max_len, device):super().__init__()# 编码矩阵初始化(关键参数说明)self.encoding = torch.zeros(max_len, d_model, device=device)self.encoding.requires_grad = False  # 冻结梯度计算# 位置索引构建(维度变换演示)pos = torch.arange(0, max_len, device=device).float().unsqueeze(dim=1)# 维度索引生成(步长控制逻辑)_2i = torch.arange(0, d_model, step=2, device=device).float()# 编码计算过程(数学实现)self.encoding[:, 0::2] = torch.sin(pos / (10000  (_2i / d_model)))self.encoding[:, 1::2] = torch.cos(pos / (10000  (_2i / d_model)))def forward(self, x):batch_size, seq_len = x.size()return self.encoding[:seq_len, :]

三、逐行代码深度解析

3.1 构造函数解析

python">def __init__(self, d_model, max_len, device):super(PositionalEncoding, self).__init__()
  • 功能说明:继承PyTorch模块基类,初始化可训练参数
  • 参数详解:
    • d_model:编码维度(需与词嵌入维度一致)
    • max_len:预计算的最大序列长度(如512对应BERT标准配置)
    • device:硬件加速配置(实现跨平台兼容)
python">    self.encoding = torch.zeros(max_len, d_model, device=device)self.encoding.requires_grad = False
  • 设计意图:创建静态编码矩阵,避免反向传播计算
  • 内存优化:通过requires_grad=False节省显存占用
  • 维度说明:矩阵形状为[max_len, d_model],例如max_len=512时生成512x512矩阵
python">    pos = torch.arange(0, max_len, device=device)pos = pos.float().unsqueeze(dim=1)
  • 位置索引构建:生成[0,1,…,max_len-1]的连续位置序列
  • 维度变换:通过unsqueeze将1D张量转换为2D(max_len,1),便于广播计算
python">    _2i = torch.arange(0, d_model, step=2, device=device).float()
  • 步长控制:step=2确保交替访问奇偶索引
  • 数值范围:当d_model=512时,生成[0,2,4,…,510]的索引序列
python">    self.encoding[:, 0::2] = torch.sin(pos / (10000  (_2i / d_model)))self.encoding[:, 1::2] = torch.cos(pos / (10000  (_2i / d_model)))
  • 分片赋值:通过0::21::2实现奇偶列交替填充
  • 频率控制:10000 (_2i/d_model)生成指数衰减的频率系数

3.2 前向传播解析

python">def forward(self, x):batch_size, seq_len = x.size()return self.encoding[:seq_len, :]
  • 动态适配:根据实际输入序列长度截取编码
  • 广播机制:自动扩展编码矩阵到批次维度(无需显式复制)
  • 数值叠加:后续与词嵌入进行element-wise相加操作

四、张量运算可视化演示

4.1 示例参数配置

假设:

  • d_model = 4
  • max_len = 3
  • device = 'cpu'

4.2 计算过程推演

步骤1:生成位置索引

pos = [[0],[1],[2]]  # shape (3,1)

步骤2:创建维度索引

_2i = [0, 2]  # d_model=4时step=2生成

步骤3:计算频率项

频率项 = 10000^( (0/4), (2/4) ) = [1, 10000^0.5] ≈ [1, 100]

步骤4:计算位置编码

sin项:
pos / [1, 100] = [[0/1, 0/100],[1/1, 1/100],[2/1, 2/100]]= [[0, 0],[1, 0.01],[2, 0.02]]
sin值:
[[0, 0],[0.8415, 0.00999983],[0.9093, 0.01999867]]cos项计算同理...

最终编码矩阵:

PE = [[sin(0), cos(0), sin(0), cos(0)],      # 位置0[sin(1), cos(0.01), sin(1), cos(0.01)],# 位置1[sin(2), cos(0.02), sin(2), cos(0.02)] # 位置2
]

五、工程实践与优化策略

5.1 配置参数建议

  1. max_len设定:应大于训练数据最大序列长度20%
  2. 设备兼容性:通过device参数统一管理计算设备
  3. 混合精度训练:可将编码矩阵转为half精度

5.2 性能优化技巧

  1. 预计算缓存:提前生成编码矩阵避免运行时计算
  2. 内存映射:对超长序列使用内存映射文件
  3. 稀疏矩阵:对长文本场景采用分块加载策略

六、与其他模块的协同工作

6.1 与词嵌入的集成

python">class TransformerEmbedding(nn.Module):def __init__(self, vocab_size, d_model, max_len, device, dropout):super().__init__()self.tok_emb = nn.Embedding(vocab_size, d_model)self.pos_emb = PositionalEncoding(d_model, max_len, device)self.dropout = nn.Dropout(dropout)def forward(self, x):tok_emb = self.tok_emb(x)pos_emb = self.pos_emb(x)return self.dropout(tok_emb + pos_emb)
  • 加法融合:通过element-wise相加实现信息融合
  • 梯度隔离:位置编码不参与梯度更新
  • 维度验证:确保tok_embpos_emb维度严格一致

七、典型应用场景分析

7.1 文本生成任务

  • 长序列处理:通过位置编码捕获远距离依赖
  • 解码器优化:在自回归生成时动态调整位置编码

7.2 语音识别系统

  • 时序建模:精确捕捉语音信号的时序特征
  • 多尺度编码:结合不同频率分量处理语音信号

八、扩展研究方向

  1. 相对位置编码:改进绝对位置编码的局限性
  2. 动态频率调整:根据输入数据自动调节频率参数
  3. 混合编码方案:结合可学习参数与固定编码
  4. 量子化压缩:对编码矩阵进行低比特量化

原项目代码(附)

python">"""
@author : Hyunwoong
@when : 2019-10-22
@homepage : https://github.com/gusdnd852
"""import torch
from torch import nn# 定义一个名为PositionalEncoding的类,它继承自nn.Module,用于计算正弦位置编码。
class PositionalEncoding(nn.Module):"""计算正弦位置编码的类。"""def __init__(self, d_model, max_len, device):"""PositionalEncoding类的构造函数。:param d_model: 模型的维度(即嵌入向量的大小)。:param max_len: 序列的最大长度。:param device: 硬件设备设置(CPU或GPU)。"""super(PositionalEncoding, self).__init__()  # 调用父类nn.Module的构造函数。# 初始化一个与输入矩阵大小相同的零矩阵,用于存储位置编码,以便后续与输入矩阵相加。self.encoding = torch.zeros(max_len, d_model, device=device)self.encoding.requires_grad = False  # 我们不需要计算位置编码的梯度。# 创建一个从0到max_len-1的一维张量,表示序列中的位置索引。pos = torch.arange(0, max_len, device=device)# 将位置索引张量转换为浮点数,并增加一个维度,从1D变为2D,以表示每个位置的索引。pos = pos.float().unsqueeze(dim=1)# 1D => 2D,增加维度以表示单词的位置。# 创建一个从0到d_model-1,步长为2的一维浮点数张量,用于计算正弦和余弦函数的指数部分。_2i = torch.arange(0, d_model, step=2, device=device).float()# 'i'表示d_model的索引(例如,嵌入大小=50时,'i'的范围为[0,50])。# "step=2"意味着'i'每次增加2(相当于2*i)。# 使用正弦函数计算位置编码的偶数索引位置的值。self.encoding[:, 0::2] = torch.sin(pos / (10000 ** (_2i / d_model)))# 使用余弦函数计算位置编码的奇数索引位置的值。self.encoding[:, 1::2] = torch.cos(pos / (10000 ** (_2i / d_model)))# 计算位置编码,以考虑单词的位置信息。def forward(self, x):# self.encoding是预先计算好的位置编码矩阵。# [max_len = 512, d_model = 512],表示最大长度为512,维度为512的位置编码。# 获取输入x的批次大小和序列长度。batch_size, seq_len = x.size()# [batch_size = 128, seq_len = 30],表示批次大小为128,序列长度为30。# 返回与输入序列长度相匹配的位置编码。return self.encoding[:seq_len, :]# [seq_len = 30, d_model = 512],返回的形状为序列长度乘以维度。# 它将与输入嵌入(tok_emb)相加,tok_emb的形状通常为[128, 30, 512]。

http://www.ppmy.cn/ops/162319.html

相关文章

Flutter 学习之旅 之 flutter 在 Android 端进行简单的图片裁剪操作

Flutter 学习之旅 之 flutter 在 Android 端进行简单的图片裁剪操作 目录 Flutter 学习之旅 之 flutter 在 Android 端进行简单的图片裁剪操作 一、简单介绍 二、简单介绍 image_cropper 三、安装 image_picker 四、简单案例实现 五、关键代码 一、简单介绍 Flutter 是一…

vue+element ui 实现选择季度组件

1、实现效果 和element ui 选择月份同样的效果&#xff0c;也支持键盘上下左右控制选择季度。 2、页面调用 <ElQuarterPicker v-model"value" placeholder"选择季度"></ElQuarterPicker> 3、组件代码&#xff1a; ElQuarterPicker.vue &l…

华为在不同发展时期的战略选择(节选)

华为在不同发展时期的战略选择&#xff08;节选&#xff09; 添加图片注释&#xff0c;不超过 140 字&#xff08;可选&#xff09; 来源&#xff1a;谢宁专著《华为战略管理法&#xff1a;DSTE实战体系》。本文有节选修改。 导言 从目前所取得的成就往回看&#xff0c;华为…

C语言【进阶篇】之指针——涵盖基础、数组与高级概念

目录 &#x1f680;前言&#x1f914;指针是什么&#x1f31f;指针基础&#x1f4af;内存与地址&#x1f4af;指针变量&#x1f4af; 指针类型&#x1f4af;const 修饰指针&#x1f4af;指针运算&#x1f4af;野指针和 assert 断言 &#x1f4bb;数组与指针&#x1f4af;数组名…

深度学习基础--ResNet50V2网络的讲解,ResNet50V2的复现(pytorch)以及用复现的ResNet50做鸟类图像分类

&#x1f368; 本文为&#x1f517;365天深度学习训练营 中的学习记录博客&#x1f356; 原作者&#xff1a;K同学啊 前言 如果说最经典的神经网络&#xff0c;ResNet肯定是一个&#xff0c;从ResNet发布后&#xff0c;作者又进行修改&#xff0c;命名为ResNe50v2&#xff0c…

Cursor+pycharm接入Codeuim(免费版),Tab自动补全功能平替

如题&#xff0c;笔者在Cursor中使用pycharm写python程序&#xff0c;试用期到了Tab自动补全功能就不能用了&#xff0c;安装Codeuim插件可以代替这个功能。步骤如下&#xff1a; 1. 在应用商店中搜索扩展Codeuim&#xff0c;下载安装 2. 安装完成后左下角会弹出提示框&#x…

用HTML5+CSS+JavaScript实现新奇挂钟动画

用HTML5+CSS+JavaScript实现新奇挂钟动画 引言 在技术博客中,如何吸引粉丝并保持他们的关注?除了干货内容,独特的视觉效果也是关键。今天,我们将通过HTML5、CSS和JavaScript实现一个新奇挂钟动画,并将其嵌入到你的网站中。这个动画不仅能让你的网站脱颖而出,还能展示你的…

GitHub开源协议选择指南:如何为你的项目找到最佳“许可证”?

引言 当你站在GitHub仓库创建的十字路口时&#xff0c;是否曾被众多开源协议晃花了眼&#xff1f; 别担心&#xff01;这篇指南将化身你的"协议导航仪"&#xff0c;用一张流程图五个灵魂拷问&#xff0c;帮你轻松找到最佳选择。无论你是开发者、开源爱好者&#xff…