【llm对话系统】大模型源码分析之 LLaMA 位置编码 RoPE

news/2025/2/2 21:22:47/

在自然语言处理(NLP)领域,Transformer 模型已经成为主流。然而,Transformer 本身并不具备处理序列顺序的能力。为了让模型理解文本中词语的相对位置,我们需要引入位置编码(Positional Encoding)。本文将深入探讨 LLaMA 模型中使用的 Rotary Embedding(旋转式嵌入)位置编码方法,并对比传统的 Transformer 位置编码方案,分析其设计与实现的优势。

1. 传统 Transformer 的位置编码

1.1 正弦余弦编码

在原始的 Transformer 模型中,使用了基于正弦和余弦函数的位置编码。这种编码方式的公式如下:

PE(pos, 2i)   = sin(pos / 10000^(2i/d_model))
PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))

其中:

  • pos 代表词语在序列中的位置。
  • i 代表编码向量的维度索引。
  • d_model 是模型的维度大小。

这种编码方式的主要特点是:

  • 绝对位置编码: 为每个位置生成唯一的向量。
  • 易于泛化到更长的序列: 可以外推到训练期间未见过的序列长度。
  • 维度变化: 编码向量的每个维度上的频率都不同。

1.2 代码示例 (PyTorch)

import torch
import mathdef positional_encoding(pos, d_model):pe = torch.zeros(1, d_model)for i in range(0, d_model, 2):pe[0, i] = math.sin(pos / (10000 ** (i / d_model)))pe[0, i + 1] = math.cos(pos / (10000 ** (i / d_model)))return pe# 示例
d_model = 512
max_len = 10
pos_encodings = torch.stack([positional_encoding(i, d_model) for i in range(max_len)])print("Position Encodings Shape:", pos_encodings.shape) # 输出: torch.Size([10, 1, 512])
print("First 3 position encodings:\n", pos_encodings[:3])

1.3 缺点

传统的正弦余弦位置编码虽然有效,但也有其局限性:

  • 缺乏相对位置信息: 尽管编码能提供绝对位置,但难以直接捕捉词语之间的相对距离关系。
  • 位置编码与输入向量独立: 位置编码是直接加到输入词向量上的,没有与词向量进行交互,信息损失比较明显。

2. LLaMA 的 Rotary Embedding (RoPE)

LLaMA 模型采用了 Rotary Embedding(RoPE),一种相对位置编码方法,它通过旋转的方式将位置信息嵌入到词向量中。RoPE 的核心思想是将位置信息编码为旋转矩阵,然后将词向量进行旋转,从而引入位置信息。

2.1 RoPE 的核心公式

RoPE 的核心公式如下:

RoPE(q, k, pos) = rotate(q, pos, Θ)

其中:

  • qk 分别代表查询向量和键向量。
  • pos 是两个向量之间的相对位置。
  • Θ 是一个旋转矩阵,根据 pos 和预定义的频率生成。
  • rotate(q, pos, Θ) 表示将 q 旋转 Θ 角度后的结果。

更具体来说,对于维度为 d 的向量 q,RoPE 将其分为 d/2 对 (q0, q1), (q2, q3) …, (qd-2, qd-1)。每个维度对应用不同的旋转角度。旋转矩阵 R 的定义是:

R(pos) =  [[cos(pos * θ_0), -sin(pos * θ_0)],[sin(pos * θ_0),  cos(pos * θ_0)]]  [[cos(pos * θ_1), -sin(pos * θ_1)],[sin(pos * θ_1),  cos(pos * θ_1)]]...[[cos(pos * θ_d/2-1), -sin(pos * θ_d/2-1)],[sin(pos * θ_d/2-1),  cos(pos * θ_d/2-1)]]

其中 θ_i = 10000^(-2i/d) ,每个维度对的旋转角度不同。

将旋转矩阵应用于向量 q ,就是:
q_rotated = R(pos) * q

2.2 LLaMA 源码实现

下面是 LLaMA 中 RoPE 的核心代码(简化版,使用 PyTorch):

import torch
import mathdef precompute_freqs(dim, end, theta=10000.0):freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))t = torch.arange(end)freqs = torch.outer(t, freqs)return torch.cat((freqs, freqs), dim=1)def apply_rotary_emb(xq, xk, freqs):xq_complex = torch.complex(xq.float(), torch.roll(xq.float(), shifts=-xq.shape[-1]//2, dims=-1))xk_complex = torch.complex(xk.float(), torch.roll(xk.float(), shifts=-xk.shape[-1]//2, dims=-1))freqs_complex = torch.complex(torch.cos(freqs), torch.sin(freqs))xq_rotated = xq_complex * freqs_complexxk_rotated = xk_complex * freqs_complexreturn xq_rotated.real.type_as(xq), xk_rotated.real.type_as(xk)# 示例
batch_size = 2
seq_len = 5
d_model = 512
head_dim = d_model//8
xq = torch.randn(batch_size, seq_len, 8, head_dim) # 输入查询向量
xk = torch.randn(batch_size, seq_len, 8, head_dim) # 输入键向量freqs = precompute_freqs(head_dim, seq_len)
xq_rotated, xk_rotated  = apply_rotary_emb(xq, xk, freqs)
print("Rotated Query Shape:", xq_rotated.shape)
print("Rotated Key Shape:", xk_rotated.shape)

代码解释

  1. precompute_freqs(dim, end, theta):
    • 此函数用于预计算旋转矩阵中使用的频率。
    • dim: 表示词向量维度。
    • end: 表示最大序列长度。
    • 返回包含所有位置的频率列表。
  2. apply_rotary_emb(xq, xk, freqs):
    • 函数将旋转操作应用于查询向量 xq 和键向量 xk
    • 通过 complex 表示实数向量的旋转,并使用复数乘法完成旋转操作。
    • 使用 torch.roll() 函数将 xq 分成实部和虚部,使用complex类型可以更快的完成旋转计算,避免了循环遍历,提高计算速度。
    • 使用复数乘法完成旋转,通过 .real 属性取出旋转后的实部,并将类型转换回原始类型

2.3 RoPE 的优势

与传统的正弦余弦位置编码相比,RoPE 具有以下优势:

  1. 相对位置编码: RoPE 专注于编码词语之间的相对位置信息,而不仅仅是绝对位置。通过向量旋转,使得向量之间的相对位置信息更直观。
  2. 高效计算: 通过使用复数乘法,RoPE 可以在GPU上进行高效的并行计算。
  3. 良好的外推能力: RoPE 可以比较容易地推广到训练期间未见过的序列长度,并且性能保持稳定。
  4. 可解释性: RoPE 的旋转操作使其相对位置信息具有更强的可解释性,有助于理解模型的行为。

3. 总结

本文详细介绍了 LLaMA 模型中使用的 Rotary Embedding 位置编码方法。通过源码分析和对比传统的位置编码,我们了解了 RoPE 的核心原理和优势。RoPE 通过旋转操作高效地编码相对位置信息,为 LLaMA 模型的强大性能提供了重要的基础。希望本文能帮助你更深入地理解 Transformer 模型中的位置编码机制。

4. 参考资料

  • RoFormer: Enhanced Transformer with Rotary Position Embedding
  • Attention is All You Need

http://www.ppmy.cn/news/1568785.html

相关文章

智能家居环境监测系统设计(论文+源码)

1. 系统方案 系统由9个部分构成,分别是电源模块、烟雾传感器模块、GSM发送短信模块、报警模块、温度传感器模块、人体红外感应模块、按键设置模块、显示模块、MCU模块。各模块的作用如下:电源模块为系统提供电力;烟雾传感器模块检测烟雾浓度&…

快速提升网站收录:利用网站FAQ页面

本文转自:百万收录网 原文链接:https://www.baiwanshoulu.com/48.html 利用网站FAQ(FrequentlyAskedQuestions,常见问题解答)页面是快速提升网站收录的有效策略之一。以下是一些具体的方法和建议,以帮助你…

docker如何查看容器启动命令(已运行的容器)

docker ps 查看正在运行的容器 该命令主要是为了详细展示查看运行时的command参数 # 通过docker --no-trunc参数来详细展示容器运行命令 docker ps -a --no-trunc | grep <container_name>通过docker inspect命令 使用docker inspect&#xff0c;但是docker inspect打…

单机伪分布Hadoop详细配置

目录 1. 引言2. 配置单机Hadoop2.1 下载并解压JDK1.8、Hadoop3.3.62.2 配置环境变量2.3 验证JDK、Hadoop配置 3. 伪分布Hadoop3.1 配置ssh免密码登录3.2 配置伪分布Hadoop3.2.1 修改hadoop-env.sh3.2.2 修改core-site.xml3.2.3 修改hdfs-site.xml3.2.4 修改yarn-site.xml3.2.5 …

FreeRTOS从入门到精通 第十八章(Tickless低功耗模式)

参考教程&#xff1a;【正点原子】手把手教你学FreeRTOS实时系统_哔哩哔哩_bilibili 一、低功耗模式概述 1、低功耗模式简介 &#xff08;1&#xff09;一般MCU都有相应的低功耗模式&#xff0c;裸机开发时可以使用MCU的低功耗模式。 &#xff08;2&#xff09;FreeRTOS提供…

C++编程语言:抽象机制:模板(Bjarne Stroustrup)

目录 23.1 引言和概观(Introduction and Overview) 23.2 一个简单的字符串模板(A Simple String Template) 23.2.1 模板的定义(Defining a Template) 23.2.2 模板实例化(Template Instantiation) 23.3 类型检查(Type Checking) 23.3.1 类型等价(Type Equivalence) …

基于Springboot + vue实现的美发门店管理系统

💖学习知识需费心, 📕整理归纳更费神。 🎉源码免费人人喜, 🔥码农福利等你领! 💖常来我家多看看, 📕网址:扣棣编程, 🎉感谢支持常陪伴, 🔥点赞关注别忘记! 💖山高路远坑又深, 📕大军纵横任驰奔, 🎉谁敢横刀立马行? 🔥唯有点赞+关注成! �…

油漆面积——蓝桥杯

1.题目描述 X 星球的一批考古机器人正在一片废墟上考古。 该区域的地面坚硬如石、平整如镜。 管理人员为方便&#xff0c;建立了标准的直角坐标系。 每个机器人都各有特长、身怀绝技。它们感兴趣的内容也不相同。 经过各种测量&#xff0c;每个机器人都会报告一个或多个矩…