Llama网络结构介绍

devtools/2024/9/23 9:30:17/

LLaMA现在已经是开源社区里炙手可热的模型了,但是原文中仅仅介绍了其和标准Transformer的差别,并没有一个全局的模型介绍。因此打算写篇文章,争取让读者不参考任何其他资料把LLaMA的模型搞懂。

结构

如图所示为LLaMA的示意图,由Attention和MLP层堆叠而成
在这里插入图片描述
LLaMA模型主要由Attention和MLP层堆叠而成,具有以下特点:
1、前置的RMSNorm:RMSNorm是一种归一化技术,用于稳定模型的训练过程,提高模型的收敛速度。
2、Q、K上的RoPE旋转式位置编码:位置编码用于捕捉序列中的位置信息,RoPE旋转式位置编码能够有效地处理长序列,提高模型的性能。
3、Causal mask:该机制保证每个位置只能看到前面的tokens,确保了模型的自回归性质。
4、使用了Group Query Attention:通过使用分组查询注意力(GQA),LLaMA能够在保持性能的同时,降低模型的计算复杂度,提高推理速度。
5、MLP表达式:down(up(x) * SILU(gate(x))),其中down, up, gate都是线性层
LLaMA各个不同大小的结构设置如下表所示。其中最大的65B的LLaMA用了2048张80GB的A100,batch size为4百万,训练一次需要21天。

Group Query Attention(V2 only)

自回归模型生成回答时,需要前面生成的KV缓存起来,来加速计算。多头注意力机制(MHA)需要的缓存量很大,Multi-Query Attention指出多个头之间可以共享KV对。Group Query Attention没有像MQA一样极端,将query分组,组内共享KV,效果接近MHA,速度上与MQA可比较。p.s. 这个技术falcon已经用上了,当时falcon说自己用的是multi query attention,因为当group=1时,GQA和MQA是等价的。falcon支持设置不同的G。
在这里插入图片描述

RMSNorm

这是在BERT、GPT等模型中广泛使用的LayerNorm:
在这里插入图片描述
RMSNorm(root mean square)发现LayerNorm的中心偏移没什么用(减去均值等操作)。将其去掉之后,效果几乎不变,但是速度提升了40%。最终公式为:
在这里插入图片描述
注意除了没有减均值,加偏置以外,分母上求的RMS而不是方差。

LLaMA在 Attention Layer和MLP的输入上使用了RMSNorm,相比在输出上使用,训练会更加稳定。

SwiGLU

LLaMA没有使用ReLU,而是使用了SwiGLU,有时也被称为SiLU。公式为:
,效果类似平滑版的ReLU:
在这里插入图片描述

RoPE

LLaMA使用了Rotary Position Embedding。对于Q的第m个位置向量q,通过以下方法注入位置编码:
在这里插入图片描述

class LlamaRotaryEmbedding(torch.nn.Module):def __init__(self, dim, max_position_embeddings=2048, base=10000):super().__init__()theta = 1.0 / (base ** (torch.arange(0, dim, 2) / dim))t = torch.arange(max_position_mbeddings)freqs = torch.einsum("i,j->ij", t, theta)emb = torch.cat((freqs, freqs), dim=-1)self.register_buffer("cos_cached", emb.cos())self.register_buffer("sin_cached", emb.sin())def forward(self, seq_len=None):return self.cos_cached[:, :, :seq_len, ...], self.sin_cached[:, :, :seq_len, ...]# 在LlamaAttention通过以下命令调用:
cos, sin = self.rotary_emb(seq_len=kv_seq_len)

以下代码将q沿着最后一个维度劈成两半,将后一半乘-1,然后连接在第一半之前,就得到了上式第三项。

# 在接下来的apply_rotary_pos_emb函数里调用def rotate_half(x):x1 = x[..., : x.shape[-1] // 2]x2 = x[..., x.shape[-1] // 2 :]return torch.cat((-x2, x1), dim=-1)

最后通过以下代码得到结合了位置编码的Q,K(K和Q使用同样的方式进行位置编码)。

def apply_rotary_pos_emb(q, k, cos, sin, position_ids):q_embed = (q * cos[position_ids]) + (rotate_half(q) * sin[position_ids])k_embed = (k * cos[position_ids]) + (rotate_half(k) * sin[position_ids])return q_embed, k_embed# 在LlamaAttention中通过以下命令调用:
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

绝对位置编码的优点是计算速度快等,缺点是拓展长度比较麻烦,且绝对位置并没有什么实际意义。而相对位置编码对学习token之间的关系很有意义,比如距离的很远的两个token之间的关联大概率很小,使用相对位置编码往往能够获得更好的效果。此外拓展长度也更容易,因为不论context size多长,只需关注最长距离以内的输入即可。相对位置编码的缺点是没有绝对位置编码计算速度快。

当我们计算Attention时,RoPE可以变成相对位置编码。
在这里插入图片描述
从上面这个公式可以看出,q和k的attention依赖相对距离m-n。因此RoPE为q、k注入的绝对位置编码,计算得到的attention,却变成了相对位置编码。妙的很,我这里为了不参考其他文章就很容易搞懂LLaMA的结构,简化了很多东西,推荐大家看一看RoPE原作者苏剑林的博客了解更多信息。

本文只关注LLaMA缺失的模型结构方面的介绍,对于文章的翻译可以参考其他的文章,
例如:靳伟,LLaMA大模型是如何炼成的,
其他参考文章:https://zhuanlan.zhihu.com/p/636784644
原文:https://arxiv.org/pdf/2302.13971.pdf。
文中参考的代码是huggingface的transformers库实现的版本,并不是Meta官方的代码。
备注说明:受笔者水平限制,如果哪里讲的不对,或者不够清晰易懂,欢迎在评论区与我交流。


http://www.ppmy.cn/devtools/18128.html

相关文章

数据结构-二叉树-堆(二)

一、建堆的时间复杂度问题 1、除了向上调整建堆,我们还可以向下调整建堆。不能在根上直接开始向下调整。这里的条件就是左右子树必须都是大堆或者小堆。我们可以倒着往前走,可以从最后一个叶子开始调整。但是从叶子开始调整没有意义。所以我们可以从倒数…

C语言 switch语句

之前 我们讲了 if 和 嵌套的if分支语句 但其实 多分支语句 我们还可以用 switch 有时 switch 语句可以简化逻辑代码 switch语句也称之为开关语句,其像多路开关一样,使程序控制流程形成多个分支,根据一个表达式的不同取值,选择其…

antd级联选择器如何使用后台的数据字段替换option里面的lable和value以及children

其主要运用了antd Cascader组件的fieldNames属性 import React from react; import { Cascader } from antd;const options [{id: 1,name: 选项1,children: [{id: 11,name: 子选项1,},],},{id: 2,name: 选项2,children: [{id: 21,name: 子选项2,},],}, ];const App () > …

【解决Android Studio】Could not resolve com.android.tools.build:gradle:7.4.1

【报错信息】 所用IDE为Android studio2022 1.1 Patch 1。 使用Android Studio新创建的新工程,在build过程中报了以下错误: A problem occurred configuring root project Application. > Could not resolve all files for configuration :classpat…

Servlet、Tomcat、Control区别

1. Servlet Servlet 是一种动态网站开发技术,专门用来处理客户端的请求并生成响应。Servlet直接与Tomcat交互,处理从Tomcat传来的请求。然后生成网页或其他类型的响应发送回Tomcat,Tomcat再将这些响应返回给用户的浏览器。 2. TomCat tomc…

代码随想录算法训练营第6天 | 242. 有效的字母异位词 | 349. 两个数组的交集 | 202. 快乐数 | 1. 两数之和

242. 有效的字母异位词 题意 两个字符串中每个字符的出现次数是否一样 解 hash bool isAnagram(char* s, char* t) {int array[30];memset(array, 0, sizeof(int) * 30);for (int i 0; s[i] ! \0; i) {array[s[i] - a];}for (int i 0; t[i] ! \0; i) {array[t[i]-a]--;}…

管理系统图片登录访问

图片就是url,但是有些管理系统的图片或者文件比较机密,需要登录之后才能访问,,就需要前端进行发送图片请求的时候携带上认证token,, 返回图片的二进制,然后再渲染到页面。。 FileReader使用 ax…

网络安全实训Day16

网络空间安全实训-渗透测试 漏洞扫描 定义 扫描和探测目标范围内的主机存在哪些安全漏洞,或扫描目标范围内的那些主机存在某个指定的漏洞 漏扫工具 AWVS APPScan MSF 使用MSF扫描漏洞并利用 1.搜索需要的攻击模块 search ms17-010 2.使用攻击模块 use 模块名称…