llama源码学习·model.py[3]ROPE旋转位置编码(4)ROPE的应用

devtools/2025/3/24 8:16:28/

一、源码注释

def apply_rotary_emb(xq: torch.Tensor, # 查询矩阵xk: torch.Tensor, # 键矩阵freqs_cis: torch.Tensor, # 旋转嵌入
) -> Tuple[torch.Tensor, torch.Tensor]:# 首先将xq和xk张量转换为浮点数# 然后使用reshape将最后一个维度拆分为两个维度,每个维度都有大小为2,这样做是为了为复数张量提供实部和虚部。# 然后,torch.view_as_complex用于从实部和虚部创建复数张量# *xq.shape[:-1] 是保留原始形状的所有维度,除了最后一个维度。# -1 是一个占位符,它告诉PyTorch自动计算这个维度,以保持元素总数不变。# 2 是最后一个维度,这是为了为接下来的复数转换做准备。每个复数由两个浮点数表示(实部和虚部),所以最后一个维度是2。xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))# 将freqs_cis重新reshape以匹配xq_的形状,以便进行广播运算。freqs_cis = reshape_for_broadcast(freqs_cis, xq_)# 这两行代码将查询和键张量与旋转嵌入相乘,应用位置嵌入。# 函数计算xq_和xk_与freqs_cis的元素乘积(这是一个复数乘法),# 在复数乘法中,xq_和xk_的实部和虚部会分别与freqs_cis的实部和虚部进行乘法运算。# flatten(3) 将两个最后的维度合并回一个维度xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)# 函数返回经过旋转嵌入处理的查询和键张量,同时确保它们的数据类型与原始输入相匹配。return xq_out.type_as(xq), xk_out.type_as(xk)

二、举例说明

# query矩阵
xq = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])  
# key矩阵
xk = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
# 频率张量
freqs_cis = torch.tensor([[1.0000+0.0000j], [1.0000+0.0000j]])  

*** xq.shape: *** torch.Size([2, 2, 2])

*** xk.shape: *** torch.Size([2, 2, 2])

freqs_cis.shape: torch.Size([2, 1])

# 首先,apply_rotary_emb函数会将query和key矩阵reshape并转换为复数张量。
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) 
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))

xq.float().reshape(*xq.shape[:-1], *-*1, 2).shape: torch.Size([2, 2, 1, 2])

xk.float().reshape(*xk.shape[:-1], *-*1, 2).shape: torch.Size([2, 2, 1, 2])

xq_.shape: torch.Size([2, 2, 1])

xk_.shape: torch.Size([2, 2, 1])

# freqs_cis 的形状是 (2, 1),xq_ 的形状是(2, 2, 1), 所以我们需要将freqs_cis形状调整为 (1, 2, 2, 1)
freqs_cis_new = reshape_for_broadcast(freqs_cis, xq_)

freqs_cis_new.shape : freqs_cis_new.shape

# 函数会将输入复数张量与频率张量相乘。
xq_out_complex = xq_ * freqs_cis_new
xk_out_complex = xk_ * freqs_cis_new

xq_out_complex: tensor([[[1.+2.j], [3.+4.j]], [[5.+6.j], [7.+8.j]]])

# 将结果重塑并转换回实数张量。
xq_out = torch.view_as_real(xq_out_complex).flatten(3)
xk_out = torch.view_as_real(xk_out_complex).flatten(3)

xq_out: tensor([[[[ 1., 2.], [ 6., 8.]], [[15., 18.], [28., 32.]]]])


http://www.ppmy.cn/devtools/169459.html

相关文章

《南京日报》专题报道 | 耘瞳科技“工业之眼”加码“中国智造”

在江宁开发区,机器人已不再是科幻电影里的遥远想象,他们就像人类的“同事”,在工地上忙着贴砖、刷墙、搬运、检测; 在体育训练场上帮助运动员矫正姿势; 在医院里帮助医生发现帕金森早期征兆,在智慧工厂里…

C# 派生 详解

1.1派生 继承设计的目的:经常需要扩展现有类型来添加功能(行为和数据)。 定义派生类 要在类标识符后添加冒号,接着添加基类名称。 注意:1.通过继承,基类的每个成员都出现在派生类构成的链条中。 2.除非明…

开源模型应用落地-LangGraph101-多智能体协同实践(六)

一、前言 随着人工智能技术的快速发展,如何高效处理复杂任务成了 AI 系统的一大挑战。传统的线性架构在面对多轮对话和动态决策时常常显得无能为力。而 LangGraph 这种多智能体合作框架的出现,为这个问题提供了新的解决方案。 相关文章: 开源模型应用落地-LangGraph101-探索…

react 常用插件

ts项目中如果提示 path不存需要安装 pnpm i types/node --save-dev常用插件 axios ajax请求echarts 图表插件reduxjs/toolkit redux 插件antd ui插件nprogress页面上方或者下方加载loding 常用语路由跳转使用dayjs 日期格式转换react-quill-new 富文本组件 全面兼容react 18r…

【STM32】USART串口协议串口外设-学习笔记

串口协议 通信接口 通信的目的:将一个设备的数据传送到另一个设备,扩展硬件系统。比如STM32芯片内部集成了很多功能模块,像定时器计数、PWM输出、AD采集等等。这些都是芯片内部的电路,这些电路的配置寄存器,数据寄存…

深入解析:Nginx+Keepalived实现双机热备架构

全文目录: 开篇语前言摘要概述什么是双机热备?为什么选择 Nginx Keepalived?本文目标 架构设计与原理架构示意图工作原理 环境准备系统与软件环境基础网络配置 实战:Nginx Keepalived 双机热备配置第一步:安装 Nginx…

mac npm run dev报错 error:0308010C:digital envelope routines::unsupported

并且提示 Unsupported engine { npm WARN EBADENGINE package: achrinza/node-ipc9.2.2, npm WARN EBADENGINE required: { node: 8 || 10 || 12 || 14 || 16 || 17 }, npm WARN EBADENGINE current: { node: v18.18.0, npm: 9.8.1 } npm WARN EBADENGINE } package.jso…

生成PDF文件:从html2canvas和jsPdf渲染到Puppeteer矢量图

刚刚实现而已:第一次明白,双击或file:///打开html文件,居然和从localhost:3000打开同一个html文件有本质的区别。 字体居然还能以Base64代码嵌入到网页,只是太大太笨。 需要安装node.js,npm安装更多依赖:…