【多模态大模型】LLaMA in arXiv 2023

ops/2024/9/18 12:45:43/ 标签: llama, SwiGLU, RoPE, FlashAttention, RMSNorm, 原理, 代码

一、引言

论文: LLaMA: Open and Efficient Foundation Language Models
作者: Meta AI
代码 LLaMA
特点: 该方法在Transformer的基础上增加了Pre-normalization (RMSNorm)、SwiGLU activation function (SwiGLU)、Rotary Embeddings (RoPE)、FlashAttention

⚠️ 在学习该方法前,建议补充BatchNorm、LayerNorm、位置编码、Attention的相关知识。

二、详情

Transformer和LLaMA的结构图如下:

可见,其结构差异主要体现在如下方面:

  • Transformer采用了左编码器+右解码器(Encoder+Decoder)的结构,LLaMA采用了仅解码器(Decoder-only)的结构。由于仅包含解码器不需要与编码器输出交互,故LLaMA去掉了Transformer中Decoder中间的交叉Multi-Head Attention和Add & Norm。
  • LLaMA采用了归一化前置(Pre-normalization)的策略,将归一化操作放在了注意力、FFN前并在线性映射前增加了一个归一化。此外,LLaMA还将LayerNorm替换为了RMSNorm
  • LLaMA将绝对位置编码替换为了旋转位置编码,即RoPE,这是一种只对Q和K进行位置编码的方式。
  • 为加速训练,LLaMA引入了FlashAttention
  • LLaMA将ReLU替换为了SwiGLU

RMSNorm_18">2.1 RMSNorm

均方根归一化RMSNorm简化了LayerNorm的计算。

要了解RMSNorm,首先需回顾LayerNorm的公式:

其中, x \boldsymbol{x} x为输入的token序列, E [ x ] = 1 n ∑ i = 1 n x i {\bf E}\boldsymbol{[x]}=\frac{1}{n}\sum_{i=1}^{n}\boldsymbol{x}_i E[x]=n1i=1nxi V a r [ x ] = 1 n ∑ i = 1 n ( x i − E [ x ] ) 2 {\bf Var}\boldsymbol{[x]}=\sqrt{\frac{1}{n}\sum_{i=1}^n(\boldsymbol{x}_i-{\bf E}\boldsymbol{[x]})^2} Var[x]=n1i=1n(xiE[x])2 x \boldsymbol{x} x的均值和有偏方差, ϵ \boldsymbol{\epsilon} ϵ用来防止分母为0, γ \boldsymbol{\gamma} γ β \boldsymbol{\beta} β是可学习的参数用来缩放和平移。

RMSNorm简化了LayerNorm的计算,其公式如下:

其中, R M S [ x ] = 1 n ∑ i = 1 n x i 2 {\bf RMS}\boldsymbol{[x]}=\sqrt{\frac{1}{n}\sum_{i=1}^{n}\boldsymbol{x}_i^2} RMS[x]=n1i=1nxi2 是均方根。

可见,RMSNormLayerNorm主要有如下差别:

  • RMSNorm无需计算均值 E [ x ] {\bf E}[\boldsymbol{x}] E[x]
  • RMSNorm将有偏方差 V a r [ x ] {\bf Var[\boldsymbol{x}]} Var[x]替换为了均方根 R M S [ x ] {\bf RMS[\boldsymbol{x}]} RMS[x]
  • RMSNorm无需平移项 γ \boldsymbol{\gamma} γ

LayerNorm一样,RMSNorm也能以句子或单词(token)为单位进行归一化,如下给出了以token为单位的代码示例。

import torch
import torch.nn as nnclass MyRMSNorm(nn.Module):def __init__(self, hidden_dim, eps=1e-8):super().__init__()# 防止分母计算为0self._eps = eps# 仿射变换参数,缩放norm后的数据分布self._gamma = nn.Parameter(torch.ones(hidden_dim))def forward(self, input):# input(N,L,C)ms = input.pow(2).mean(dim=-1, keepdim=True)  # 计算均方,token-wiseinput = input / torch.sqrt(ms + self._eps)  # 执行标准化return input * self._gamma  # 仿射变换if __name__ == '__main__':batch_size = 4length = 2hidden_dim = 3input = torch.rand(4, 2, 3)myRMSN = MyRMSNorm(hidden_dim=hidden_dim)MyO = myRMSN(input)pytorchRMSN = nn.RMSNorm(normalized_shape=hidden_dim, elementwise_affine=False)  # 不使用可学习的gamma和betapytorchO = pytorchRMSN(input)print(MyO == pytorchO)

RoPE_72">2.2 RoPE

旋转位置编码RoPE使用绝对位置信息设计旋转规则,使旋转后的数据能够表达相对位置信息。

要了解RoPE,首先我们来了解一下二维空间的旋转。如下图:

其中, X = [ ρ cos ⁡ ϕ , ρ sin ⁡ ϕ ] X=[\rho\cos\phi,\rho\sin\phi] X=[ρcosϕ,ρsinϕ]是一个二维向量,逆时针旋转 θ \theta θ度变成 X R ( θ ) XR(\theta) XR(θ)。此时 R ( θ ) = [ cos ⁡ θ , sin ⁡ θ − sin ⁡ θ , cos ⁡ θ ] R(\theta)=\left[\begin{matrix}\cos\theta,~\sin\theta\\-\sin\theta,~\cos\theta\end{matrix}\right] R(θ)=[cosθ, sinθsinθ, cosθ],证明如下:

X R ( θ ) = [ ρ cos ⁡ ϕ , ρ sin ⁡ ϕ ] [ cos ⁡ θ , sin ⁡ θ − sin ⁡ θ , cos ⁡ θ ] = ρ [ cos ⁡ ϕ cos ⁡ θ − sin ⁡ ϕ sin ⁡ θ , cos ⁡ ϕ sin ⁡ θ + sin ⁡ ϕ cos ⁡ θ ] = [ ρ cos ⁡ ( ϕ + θ ) , ρ sin ⁡ ( ϕ + θ ) ] XR(\theta)=[\rho\cos\phi,\rho\sin\phi]\left[\begin{matrix}\cos\theta,~\sin\theta\\-\sin\theta,~\cos\theta\end{matrix}\right]\\=\rho[\cos\phi\cos\theta-\sin\phi\sin\theta,\cos\phi\sin\theta+\sin\phi\cos\theta]=[\rho\cos(\phi+\theta),\rho\sin(\phi+\theta)] XR(θ)=[ρcosϕ,ρsinϕ][cosθ, sinθsinθ, cosθ]=ρ[cosϕcosθsinϕsinθ,cosϕsinθ+sinϕcosθ]=[ρcos(ϕ+θ),ρsin(ϕ+θ)]

可见, X X X X R ( θ ) XR(\theta) XR(θ)仅差一个 θ \theta θ,所以二维空间逆时针旋转 θ \theta θ度可通过 R ( θ ) R(\theta) R(θ)实现。

旋转只改变角度,不改变长度。

RoPE将旋转应用在了注意力模块的查询 Q Q Q K K K上。它将第 i i i个查询 Q i Q_i Qi旋转 i θ i\theta iθ的角度,再将第 j j j个键 K j K_j Kj旋转 j θ j\theta jθ的角度,那么 Q i K j T Q_iK_j^T QiKjT就会变成一个与相对位置 i − j i-j ij相关的值。推导过程如下:

i i i j j j是查询 Q i Q_i Qi K j K_j Kj的绝对位置, i − j i-j ij是它们的相对位置。

然而, Q i Q_i Qi K j K_j Kj的维度通常都是大于2的,我们假设它是 D D D D D D是2的整数倍,于是我们可以将 Q i Q_i Qi K j K_j Kj分别划分为 d = D 2 d=\frac{D}{2} d=2D个子空间,每个子空间都是二维的。

下图给出了一个 D = 10 D=10 D=10的例子,我们将 Q i Q_i Qi K j K_j Kj分为5个子空间并分配1个包括5个角度的旋转序列 Θ = ( θ 1 , θ 2 , ⋯ , θ 5 ) \Theta=(\theta_1,\theta_2,\cdots,\theta_5) Θ=(θ1,θ2,,θ5),每个子空间的旋转角度是在对应旋转序列的基础上乘以 i i i j j j

将其扩展到 d d d个子空间,可以得到如下信息:

其中, X i X_i Xi代指 Q i Q_i Qi K j K_j Kj。此时,这种旋转仍然具有相对位置的表达能力,证明如下:

显然,上面的 R ( i Θ ) R(i\Theta) R(iΘ)过于稀疏,为了提升计算效率,通常 d d d个子空间的旋转使用下式表达:

为避免token数过多, i θ k i\theta_k iθk j θ k j\theta_k jθk重叠导致相对位置得不到表达(同一个子空间 k k k,绝对位置 i i i j j j不同, i θ k − j θ k = 2 m π i\theta_k-j\theta_k=2m\pi iθkjθk=2时重叠, m m m是一个整数),RoPE使用了一个递减的等比数列作为 θ \theta θ序列,如下:

θ k \theta_k θk是递减的,这表示token中前几个子空间的旋转角度较大,越往后旋转角度越小。

事实上,为了方便我们通常不是将相邻的两个值划分至同一子空间,而是将D分为前后两个部分,前后各取一个依次组成子空间,例如[q0,q1,q2,q3]被划分为[q0,q2], [q1,q3]而不是[q0,q1], [q2,q3]。以下为使用这种方式进行子空间划分的RoPE代码

from torch.nn import functional as F
import torch.nn as nn
import torch
import mathclass Rotator:"""根据hidden_dim,和position_ids 生成对应的旋转位置编码, 和论文中定义略有不同,一个个二维的子空间被分割到了前后两部分,分别进行旋转,然后拼接起来"""def __init__(self, D, position_ids):""" position_ids: [seq_len], D 和单个头的hidden_dim对应 """base = 10000d = D / 2B = base ** (1/d)theta_base = 1.0 / (B ** (torch.arange(0, d)))    # 等比数列, $\Theta$thetas = position_ids.outer(theta_base)  # [seq_len, D/2]# 这里的子空间划分与讲解不同,[q0,q1,q2,q3] -> [q0,q2],[q1,q3]是两个子空间而不是[q0,q1],[q2,q3]full_thetas = torch.cat((thetas, thetas), dim=-1)  # [seq_len, D]self.cos = full_thetas.cos()self.sin = full_thetas.sin()def rotate(self, x):"""x: [bs, num_attention_heads, seq_len, D]q: [bs, num_attention_heads, seq_len, D]cos: [seq_len, D][x,y] @ [[cos, sin], [-sin, cos]] = [x*cos-y*sin, ycos+x*sin] =[x,y]*cos+[-y, x]*sin"""return x * self.cos + Rotator.reverse_half(x) * self.sin@staticmethoddef reverse_half(q):""" q: [bs, num_attention_heads, seq_len, D] trick2 """u = q[..., :q.shape[-1] // 2]  # 认为是各个二维子空间的第一维的向量集结v = q[..., q.shape[-1] // 2:]   # 认为是各个二维子空间的第二维的向量集结return torch.cat((-v, u), dim=-1)if __name__ == "__main__":batch_size = 2num_heads = 3D = 6  # 单个头的token向量长度hidden_dim = D * num_headsseq_len = 4position_ids = torch.arange(seq_len)rotator = Rotator(D, position_ids)x = torch.randn((batch_size, seq_len, hidden_dim))# 对每个头分别进行旋转,[batch_size,seq_len,hidden_dim] -> [batch_size,seq_len,num_heads,D] -> [batch_size,num_heads,seq_len,D]x = x.view(batch_size, seq_len, num_heads, D).transpose(1, 2)x = rotator.rotate(x)

FlashAttention_168">2.3 FlashAttention

FlashAttention以分块的形式进行注意力计算,避免了SRAM和HBM之间频繁读写导致的时间浪费。

详情请参考我之前的博客FlashAttention in NeurIPS 2022。

SwiGLU_172">2.4 SwiGLU

激活函数SwiGLU是门控线性单元(Gated Linear Units, GLU)的变体,下图红框中表达了GLU的计算过程:

可见,GLU会先使用两个带偏执的线性层映射输入 x \boldsymbol{x} x,分别记为 x W 1 + b 1 \boldsymbol{xW_1+b_1} xW1+b1 x W 2 + b 2 \boldsymbol{xW_2+b_2} xW2+b2;其中一个线性映射后会跟一个非线性激活函数sigmoid,记为 σ ( x W 1 + b 1 ) \sigma(\boldsymbol{xW_1+b_1}) σ(xW1+b1);然后将左右两边的结果对应元素相乘即完成了GLU,记为 σ ( x W 1 + b 1 ) ⊗ ( x W 2 + b 2 ) \sigma(\boldsymbol{xW_1+b_1})\otimes(\boldsymbol{xW_2+b_2}) σ(xW1+b1)(xW2+b2)

SwiGLUGLU做了两点改进:

  • 去掉了两个线性映射的偏执项,此时公式变成 σ ( x W 1 ) ⊗ ( x W 2 ) \sigma(\boldsymbol{xW_1})\otimes(\boldsymbol{xW_2}) σ(xW1)(xW2)
  • sigmoid替换为了Swish,此时公式变成 Swish β ( x W 1 ) ⊗ ( x W 2 ) \text{Swish}_{\beta}(\boldsymbol{xW_1})\otimes(\boldsymbol{xW_2}) Swishβ(xW1)(xW2)

Swish的公式为 Swish β ( a ) = a σ ( β a ) = a 1 + e − β a \text{Swish}_{\beta}(a)=a\sigma(\beta a)=\frac{a}{1+e^{-\beta a}} Swishβ(a)=(βa)=1+eβaa,在不同的 β \beta β下该非线性激活函数的曲线如下:

可见,当 β \beta β较大时,该曲线与ReLU十分接近;当 β = 1 \beta=1 β=1时,小于0但接近0的曲线变得更光滑且非单调。

SwiGLU则选用了 β = 1 \beta=1 β=1Swish,于是我们得到SwiGLU的公式如下:
Swish ( x W 1 ) ⊗ ( x W 2 ) = x W 1 1 + e − x W 1 ⊗ x W 2 \text{Swish}(\boldsymbol{xW_1})\otimes(\boldsymbol{xW_2})=\frac{\boldsymbol{xW_1}}{1+e^{-\boldsymbol{xW_1}}}\otimes\boldsymbol{xW_2} Swish(xW1)(xW2)=1+exW1xW1xW2

致谢:

本博客仅做记录使用,无任何商业用途,参考内容如下:
解密旋转位置编码:数学基础、代码实现与绝对编码一体化探索
一文为你深度解析 LLaMA2 模型架构
Llama改进之——SwiGLU激活函数


http://www.ppmy.cn/ops/96795.html

相关文章

企业网站域名如何选择?

在数字化时代,企业网站是其在线形象的重要组成部分。域名作为网站的网络地址,不仅关系到用户的第一印象,也是企业品牌战略的关键元素。选择合适的域名对于建立企业在线身份、提高品牌知名度和吸引潜在客户至关重要。本文将探讨企业在选择网站…

Linux系统tar归档文件中解压指定文件

一、查看归档文件内容 要查看.tar归档文件中的文件列表,可以使用tar命令的-t(或--list)选项,该选项会列出归档文件中包含的所有文件,而不会实际解压它们。这里是一个基本的命令示例: tar -tf yourfile.ta…

AI绘画Stable Diffusion插件—LayerDiffusion 分层控图新突破!生成透明图片前后景图片融合,毫无违和感!

大家好,我是画画的小强 用AI绘画Stable Diffusion 生成透明图片怎么搞? 这要搁之前,我们需要生成完图片,然后放到去背景插件中调整参数去除背景!效果一般般 如果想要在一张图片上添加主体,该怎么搞&#…

【Redis】解析Redisson 限流器源码

Redisson 一、注解AOP 代码部分提取二、设置限流器的失效时间 一、注解AOP 代码部分提取 // 调用Reids工具类的rateLimiter 方法long number RedisUtils.rateLimiter(combineKey, rateType, count, time);redis 工具类 public class RedisUtils {private static final Redis…

Python生成432Hz音频

使用 numpy 来生成信号, 使用 matplotlib 可视化信号, 使用 sounddevice 播放声音。 以下生成和播放 432 Hz 的正弦波信号: import numpy as np import sounddevice as sd import matplotlib.pyplot as plt# 生成单音函数 def generate_to…

从简单到复杂,训练神经网络的秘诀

据称,开始训练神经网络非常简单。许多库和框架都以展示 30 行神奇的代码片段来解决问题而自豪,给人一种这些东西是即插即用的(错误)印象。常见的情况如下: >>> your_data # plug your awesome dataset here…

第二届海南大数据创新应用大赛 - 算法赛道冠军比赛攻略_海南新境界队

关联比赛: 第二届海南大数据创新应用大赛 - 智能算法赛 第二届海南大数据创新应用大赛 - 算法赛道冠军比赛攻略 首先很幸运能拿到这次初赛冠军,本着积极学习和提升自我的态度,团队成员通力合作是获胜关键,再次感谢。 赛题背景分析和理解 …

Vue 3 组合式 API 中的 nextTick 深入解析

Vue.js 是一个渐进式 JavaScript 框架,以其易学、高效和灵活的特点,成为构建交互式 Web 界面的理想选择。Vue 3 通过一系列性能提升、架构重构和改进开发体验等优点,进一步提高了 Vue.js 的优越性。在 Vue 3 中,组合式 API&#x…

第2章 C语言基础知识

第2章 C语言基础知识 1.printf()函数 在控制台输出数据,需要使用输出函数,C语言常用的输出函数为printf()。 printf()函数为格式化输出函数,其功能是按照用户指定的格式将数据输出到屏幕上。 printf(“格式控制字符串”,[输出列表]); 格式控…

基于Docker compose部署Confluence 8.3.4及设置数据持久化存储的总结

基于Docker compose部署Confluence 8.3.4及设置数据持久化存储的总结 一、环境信息二、安装部署三、向导 介绍如何基于Docker、Docker Compose的方式安装部署Confluence 8.3.4,并且设置数据的持久化存储。 一、环境信息 操作系统:CentOS 7.9 Docker Ver…

Redis系列之事务

概述 Redis事务提供一种将多个命令打包,然后一次性、按顺序地执行的机制,在事务执行的期间不会主动中断,服务器在执行完事务中的所有命令之后,才会继续处理其他客户端的其他命令。 三个重要的保证: 批量操作在发送E…

产品分析 | 便利蜂

​产品信息 产品名称:便利蜂 Slogan:小小的幸福 在你身边 版本号:V1.11.3 大小:23.6M 体验环境:Android6.0.1 品牌概述 便利蜂成立于2016年12月,算是起步较早的企业了,17年2月就开了第一家…

基于单片机的 GPS 信息处理系统

摘 要 : 介绍一种基于单片机的 GPS 信息处理系统 。 以 AT MEL 公司的单片机 AT 89C2051 作为核心控制器件 , LCD和键盘作为人机界面, 通过串行口接收 GPS 接收机输出的 NMEA 全球定位系统 ( Global Positioning System, GPS) 是美国从 20 世纪 70 年代开始研制…

基于微信小程序的课堂考勤系统的设计与实现(论文+源码)_kaic

基于微信小程序的课堂考勤系统的设计与实现 摘 要 在高校教育普及的今天,学生人数日益增多,为保证课堂质量,教师多要在课前进行考勤。因此本设计提出基于微信小程序的课堂考勤系统,增加了定位功能,避免了“假打卡”…

配置EIGRP命名模式

背景資訊 傳統的配置EIGRP的方法要求在介面和EIGRP配置模式下配置各種引數。為了配置EIGRP IPV4和IPv6,需要配置單獨的EIGRP例項。傳統EIGRP在IPv6 EIGRP實施中不支援虛擬路由和轉發(VRF)。 對於命名模式EIGRP,所有配置都在EIGRP配置下的單個位置進行配…

成为Python砖家(3): 何时产生字节码 .pyc 文件

好奇:.pyc和 __pycache__是啥? 你是否好奇,在某些 Python 工程中,当执行了 xxx.py脚本后,多出了 __pycache__目录?这个目录下存放的是一些 .pyc结尾的文件。 这些文件,叫做 python bytecode。 …

JSON.stringify 和 JSON.parse

JSON.stringify 是一个将 JavaScript 对象转换为 JSON 字符串的方法,它有三个参数: JSON.stringify(value, replacer, space) 参数详解 value (必需): 这是你想要转换为 JSON 字符串的 JavaScript 对象或数组。例如:…

基于web框架的协同过滤的美食推荐系统【数据爬虫、管理系统、数据可更新、样式可调整】

文章目录 有需要本项目的代码或文档以及全部资源,或者部署调试可以私信博主项目介绍研究背景研究的目的与意义协同过滤算法基于用户的协同过滤算法定义基于物品的协同过滤算法的定义 数据库设计db_food(美食信息表)db_collect(美食…

Leaflet中Marker加载设置SVG

Leaflet Marker 加载设置 SVG 1. SVG加载1.1. svg 路径设置问题1.2. svg 路径设置 1. SVG加载 1.1. svg 路径设置问题 Leafett 中的 Marker 是以 IMG 标签实现的。在 vue 项目中,如果直接用一下写法: img.src "./svg/404.svg" 项目运行后&a…

聊聊场景及场景测试

在我们进行测试过程中,有一种黑盒测试叫场景测试,我们完全是从用户的角度去理解系统,从而可以挖掘用户的隐含需求。 场景是指用户会使用这个系统来完成预定目标的所有情况的集合。 场景本身也代表了用户的需求,所以我们可以认为…