【llm对话系统】大模型源码分析之 LLaMA 位置编码 RoPE

ops/2025/2/3 13:07:27/

在自然语言处理(NLP)领域,Transformer 模型已经成为主流。然而,Transformer 本身并不具备处理序列顺序的能力。为了让模型理解文本中词语的相对位置,我们需要引入位置编码(Positional Encoding)。本文将深入探讨 LLaMA 模型中使用的 Rotary Embedding(旋转式嵌入)位置编码方法,并对比传统的 Transformer 位置编码方案,分析其设计与实现的优势。

1. 传统 Transformer 的位置编码

1.1 正弦余弦编码

在原始的 Transformer 模型中,使用了基于正弦和余弦函数的位置编码。这种编码方式的公式如下:

PE(pos, 2i)   = sin(pos / 10000^(2i/d_model))
PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))

其中:

  • pos 代表词语在序列中的位置。
  • i 代表编码向量的维度索引。
  • d_model 是模型的维度大小。

这种编码方式的主要特点是:

  • 绝对位置编码: 为每个位置生成唯一的向量。
  • 易于泛化到更长的序列: 可以外推到训练期间未见过的序列长度。
  • 维度变化: 编码向量的每个维度上的频率都不同。

1.2 代码示例 (PyTorch)

import torch
import mathdef positional_encoding(pos, d_model):pe = torch.zeros(1, d_model)for i in range(0, d_model, 2):pe[0, i] = math.sin(pos / (10000 ** (i / d_model)))pe[0, i + 1] = math.cos(pos / (10000 ** (i / d_model)))return pe# 示例
d_model = 512
max_len = 10
pos_encodings = torch.stack([positional_encoding(i, d_model) for i in range(max_len)])print("Position Encodings Shape:", pos_encodings.shape) # 输出: torch.Size([10, 1, 512])
print("First 3 position encodings:\n", pos_encodings[:3])

1.3 缺点

传统的正弦余弦位置编码虽然有效,但也有其局限性:

  • 缺乏相对位置信息: 尽管编码能提供绝对位置,但难以直接捕捉词语之间的相对距离关系。
  • 位置编码与输入向量独立: 位置编码是直接加到输入词向量上的,没有与词向量进行交互,信息损失比较明显。

2. LLaMA 的 Rotary Embedding (RoPE)

LLaMA 模型采用了 Rotary Embedding(RoPE),一种相对位置编码方法,它通过旋转的方式将位置信息嵌入到词向量中。RoPE 的核心思想是将位置信息编码为旋转矩阵,然后将词向量进行旋转,从而引入位置信息。

2.1 RoPE 的核心公式

RoPE 的核心公式如下:

RoPE(q, k, pos) = rotate(q, pos, Θ)

其中:

  • qk 分别代表查询向量和键向量。
  • pos 是两个向量之间的相对位置。
  • Θ 是一个旋转矩阵,根据 pos 和预定义的频率生成。
  • rotate(q, pos, Θ) 表示将 q 旋转 Θ 角度后的结果。

更具体来说,对于维度为 d 的向量 q,RoPE 将其分为 d/2 对 (q0, q1), (q2, q3) …, (qd-2, qd-1)。每个维度对应用不同的旋转角度。旋转矩阵 R 的定义是:

R(pos) =  [[cos(pos * θ_0), -sin(pos * θ_0)],[sin(pos * θ_0),  cos(pos * θ_0)]]  [[cos(pos * θ_1), -sin(pos * θ_1)],[sin(pos * θ_1),  cos(pos * θ_1)]]...[[cos(pos * θ_d/2-1), -sin(pos * θ_d/2-1)],[sin(pos * θ_d/2-1),  cos(pos * θ_d/2-1)]]

其中 θ_i = 10000^(-2i/d) ,每个维度对的旋转角度不同。

将旋转矩阵应用于向量 q ,就是:
q_rotated = R(pos) * q

2.2 LLaMA 源码实现

下面是 LLaMA 中 RoPE 的核心代码(简化版,使用 PyTorch):

import torch
import mathdef precompute_freqs(dim, end, theta=10000.0):freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))t = torch.arange(end)freqs = torch.outer(t, freqs)return torch.cat((freqs, freqs), dim=1)def apply_rotary_emb(xq, xk, freqs):xq_complex = torch.complex(xq.float(), torch.roll(xq.float(), shifts=-xq.shape[-1]//2, dims=-1))xk_complex = torch.complex(xk.float(), torch.roll(xk.float(), shifts=-xk.shape[-1]//2, dims=-1))freqs_complex = torch.complex(torch.cos(freqs), torch.sin(freqs))xq_rotated = xq_complex * freqs_complexxk_rotated = xk_complex * freqs_complexreturn xq_rotated.real.type_as(xq), xk_rotated.real.type_as(xk)# 示例
batch_size = 2
seq_len = 5
d_model = 512
head_dim = d_model//8
xq = torch.randn(batch_size, seq_len, 8, head_dim) # 输入查询向量
xk = torch.randn(batch_size, seq_len, 8, head_dim) # 输入键向量freqs = precompute_freqs(head_dim, seq_len)
xq_rotated, xk_rotated  = apply_rotary_emb(xq, xk, freqs)
print("Rotated Query Shape:", xq_rotated.shape)
print("Rotated Key Shape:", xk_rotated.shape)

代码解释

  1. precompute_freqs(dim, end, theta):
    • 此函数用于预计算旋转矩阵中使用的频率。
    • dim: 表示词向量维度。
    • end: 表示最大序列长度。
    • 返回包含所有位置的频率列表。
  2. apply_rotary_emb(xq, xk, freqs):
    • 函数将旋转操作应用于查询向量 xq 和键向量 xk
    • 通过 complex 表示实数向量的旋转,并使用复数乘法完成旋转操作。
    • 使用 torch.roll() 函数将 xq 分成实部和虚部,使用complex类型可以更快的完成旋转计算,避免了循环遍历,提高计算速度。
    • 使用复数乘法完成旋转,通过 .real 属性取出旋转后的实部,并将类型转换回原始类型

2.3 RoPE 的优势

与传统的正弦余弦位置编码相比,RoPE 具有以下优势:

  1. 相对位置编码: RoPE 专注于编码词语之间的相对位置信息,而不仅仅是绝对位置。通过向量旋转,使得向量之间的相对位置信息更直观。
  2. 高效计算: 通过使用复数乘法,RoPE 可以在GPU上进行高效的并行计算。
  3. 良好的外推能力: RoPE 可以比较容易地推广到训练期间未见过的序列长度,并且性能保持稳定。
  4. 可解释性: RoPE 的旋转操作使其相对位置信息具有更强的可解释性,有助于理解模型的行为。

3. 总结

本文详细介绍了 LLaMA 模型中使用的 Rotary Embedding 位置编码方法。通过源码分析和对比传统的位置编码,我们了解了 RoPE 的核心原理和优势。RoPE 通过旋转操作高效地编码相对位置信息,为 LLaMA 模型的强大性能提供了重要的基础。希望本文能帮助你更深入地理解 Transformer 模型中的位置编码机制。

4. 参考资料

  • RoFormer: Enhanced Transformer with Rotary Position Embedding
  • Attention is All You Need

http://www.ppmy.cn/ops/155315.html

相关文章

基于 STM32 的智能农业温室控制系统设计

1. 引言 随着农业现代化的发展,智能农业温室控制系统对于提高农作物产量和质量具有重要意义。该系统能够实时监测温室内的环境参数,如温度、湿度、光照强度和土壤湿度等,并根据这些参数自动调节温室设备,如通风扇、加热器、加湿器…

开源智慧园区管理系统对比其他十种管理软件的优势与应用前景分析

内容概要 在当今数字化快速发展的时代,园区管理软件的选择显得尤为重要。而开源智慧园区管理系统凭借其独特的优势,逐渐成为用户的新宠。与传统管理软件相比,它不仅灵活性高,而且具有更强的可定制性,让各类园区&#…

hive:数据导入,数据导出,加载数据到Hive,复制表结构

hive不建议用insert,因为Hive是建立在Hadoop之上的数据仓库工具,主要用于批处理和大数据分析,而不是为OLTP(在线事务处理)操作设计的。INSERT操作会非常慢 数据导入 命令行界面:建一个文件 查询数据>>复制>>粘贴到新…

selenium自动化测试框架——面试题整理

目录 1. 什么是 Selenium?它的工作原理是什么? 2. Selenium 主要组件 3. 常见 WebDriver 驱动 4. Selenium 如何驱动浏览器? 5. WebDriver 协议是什么? 6. Page Object 模式与 Page Factory 7. 如何判断元素是否可见&#x…

Python 类型注解

文章目录 Python 类型注解详解1. 引言2. Python 类型注解基础2.1 变量类型注解2.2 函数参数和返回值注解2.3 typing 模块的支持 3. 进阶:复杂数据类型3.1 可选类型(Optional)3.2 联合类型(Union)3.3 泛型(G…

AD电路仿真

目录 0 前言 仿真类型 仿真步骤 仿真功能及参数设置 仿真模型 应用优势 1 新建原理图 2 放置元器件及布线 3 放置探头 4 实验结果 Operating Point 分析的作用 DC Sweep 的主要功能 Transient Analysis 的主要功能 AC Analysis 的功能 5 总结 1. 直流工作点分析…

【含文档+PPT+源码】基于微信小程序连锁药店商城

项目介绍 本课程演示的是一款基于微信小程序连锁药店商城,主要针对计算机相关专业的正在做毕设的学生与需要项目实战练习的 Java 学习者。 1.包含:项目源码、项目文档、数据库脚本、软件工具等所有资料 2.带你从零开始部署运行本套系统 3.该项目附带的…

Deepseek的RL算法GRPO解读

在本文中,我们将深入探讨Deepseek采用的策略优化方法GRPO,并顺带介绍一些强化学习(Reinforcement Learning, RL)的基础知识,包括PPO等关键概念。 策略函数(policy) 在强化学习中, a…