Llama网络结构介绍

server/2024/9/18 12:52:06/ 标签: llama

LLaMA现在已经是开源社区里炙手可热的模型了,但是原文中仅仅介绍了其和标准Transformer的差别,并没有一个全局的模型介绍。因此打算写篇文章,争取让读者不参考任何其他资料把LLaMA的模型搞懂。

结构

如图所示为LLaMA的示意图,由Attention和MLP层堆叠而成
在这里插入图片描述
LLaMA模型主要由Attention和MLP层堆叠而成,具有以下特点:
1、前置的RMSNorm:RMSNorm是一种归一化技术,用于稳定模型的训练过程,提高模型的收敛速度。
2、Q、K上的RoPE旋转式位置编码:位置编码用于捕捉序列中的位置信息,RoPE旋转式位置编码能够有效地处理长序列,提高模型的性能。
3、Causal mask:该机制保证每个位置只能看到前面的tokens,确保了模型的自回归性质。
4、使用了Group Query Attention:通过使用分组查询注意力(GQA),LLaMA能够在保持性能的同时,降低模型的计算复杂度,提高推理速度。
5、MLP表达式:down(up(x) * SILU(gate(x))),其中down, up, gate都是线性层
LLaMA各个不同大小的结构设置如下表所示。其中最大的65B的LLaMA用了2048张80GB的A100,batch size为4百万,训练一次需要21天。

Group Query Attention(V2 only)

自回归模型生成回答时,需要前面生成的KV缓存起来,来加速计算。多头注意力机制(MHA)需要的缓存量很大,Multi-Query Attention指出多个头之间可以共享KV对。Group Query Attention没有像MQA一样极端,将query分组,组内共享KV,效果接近MHA,速度上与MQA可比较。p.s. 这个技术falcon已经用上了,当时falcon说自己用的是multi query attention,因为当group=1时,GQA和MQA是等价的。falcon支持设置不同的G。
在这里插入图片描述

RMSNorm

这是在BERT、GPT等模型中广泛使用的LayerNorm:
在这里插入图片描述
RMSNorm(root mean square)发现LayerNorm的中心偏移没什么用(减去均值等操作)。将其去掉之后,效果几乎不变,但是速度提升了40%。最终公式为:
在这里插入图片描述
注意除了没有减均值,加偏置以外,分母上求的RMS而不是方差。

LLaMA在 Attention Layer和MLP的输入上使用了RMSNorm,相比在输出上使用,训练会更加稳定。

SwiGLU

LLaMA没有使用ReLU,而是使用了SwiGLU,有时也被称为SiLU。公式为:
,效果类似平滑版的ReLU:
在这里插入图片描述

RoPE

LLaMA使用了Rotary Position Embedding。对于Q的第m个位置向量q,通过以下方法注入位置编码:
在这里插入图片描述

class LlamaRotaryEmbedding(torch.nn.Module):def __init__(self, dim, max_position_embeddings=2048, base=10000):super().__init__()theta = 1.0 / (base ** (torch.arange(0, dim, 2) / dim))t = torch.arange(max_position_mbeddings)freqs = torch.einsum("i,j->ij", t, theta)emb = torch.cat((freqs, freqs), dim=-1)self.register_buffer("cos_cached", emb.cos())self.register_buffer("sin_cached", emb.sin())def forward(self, seq_len=None):return self.cos_cached[:, :, :seq_len, ...], self.sin_cached[:, :, :seq_len, ...]# 在LlamaAttention通过以下命令调用:
cos, sin = self.rotary_emb(seq_len=kv_seq_len)

以下代码将q沿着最后一个维度劈成两半,将后一半乘-1,然后连接在第一半之前,就得到了上式第三项。

# 在接下来的apply_rotary_pos_emb函数里调用def rotate_half(x):x1 = x[..., : x.shape[-1] // 2]x2 = x[..., x.shape[-1] // 2 :]return torch.cat((-x2, x1), dim=-1)

最后通过以下代码得到结合了位置编码的Q,K(K和Q使用同样的方式进行位置编码)。

def apply_rotary_pos_emb(q, k, cos, sin, position_ids):q_embed = (q * cos[position_ids]) + (rotate_half(q) * sin[position_ids])k_embed = (k * cos[position_ids]) + (rotate_half(k) * sin[position_ids])return q_embed, k_embed# 在LlamaAttention中通过以下命令调用:
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

绝对位置编码的优点是计算速度快等,缺点是拓展长度比较麻烦,且绝对位置并没有什么实际意义。而相对位置编码对学习token之间的关系很有意义,比如距离的很远的两个token之间的关联大概率很小,使用相对位置编码往往能够获得更好的效果。此外拓展长度也更容易,因为不论context size多长,只需关注最长距离以内的输入即可。相对位置编码的缺点是没有绝对位置编码计算速度快。

当我们计算Attention时,RoPE可以变成相对位置编码。
在这里插入图片描述
从上面这个公式可以看出,q和k的attention依赖相对距离m-n。因此RoPE为q、k注入的绝对位置编码,计算得到的attention,却变成了相对位置编码。妙的很,我这里为了不参考其他文章就很容易搞懂LLaMA的结构,简化了很多东西,推荐大家看一看RoPE原作者苏剑林的博客了解更多信息。

本文只关注LLaMA缺失的模型结构方面的介绍,对于文章的翻译可以参考其他的文章,
例如:靳伟,LLaMA大模型是如何炼成的,
其他参考文章:https://zhuanlan.zhihu.com/p/636784644
原文:https://arxiv.org/pdf/2302.13971.pdf。
文中参考的代码是huggingface的transformers库实现的版本,并不是Meta官方的代码。
备注说明:受笔者水平限制,如果哪里讲的不对,或者不够清晰易懂,欢迎在评论区与我交流。


http://www.ppmy.cn/server/18824.html

相关文章

北京车展“第一枪”:长安汽车发布全球首款量产可变新汽车

4月25日,万众瞩目的2024北京国际汽车展览会在中国国际展览中心如期而至。作为中国乃至全球汽车行业的盛宴,本次车展也吸引了无数业内人士的高度关注。 此次北京车展以“新时代 新汽车”为主题,汇聚了1500余家主流车企及零部件制造商&#xff…

发那科FANUC机器人R-2000iB平衡缸维修攻略

在发那科机器人中,平衡缸扮演着稳定机械臂运动的关键角色。它通过内部的压力调节来平衡负载,保证机器人的精准定位和平稳操作。一旦出现法兰克机械手平衡缸故障或损坏,机器人的性能可能会大打折扣,因此及时且正确的FANUC机械手平衡…

12 c++版本的坦克大战

前言 呵呵 这大概是 大学里面的 c 贪吃蛇了吧 有一些 面向对象的理解, 但是不多 这里 具体的实现 就不赘述, 仅仅是 发一下代码 以及 具体的使用 坦克大战 #include<iostream> #include<windows.h> #include<conio.h> #include<ctime> #include…

计算机网络原原理学习资料分享笔记---第二章/第七节/第八节(为有梦想的自己加油!)

第七节 P 2 P应用 第七节 P 2 P应用 知识点 1 P 2 P 第七节 P 2 P应用 知识点 1 P 2 P 谢谢 第八节 Socket编程基础 第八节 Socket编程基础 第八节 Socket编程基础 第八节 Socket编程基础 知识点 1 Socket基本概念 第八节 Socket编程基础 第八节 Socket编程基础 4 、TCP…

Sui主网升级至V1.23.1版本

其他升级要点如下所示&#xff1a; #17126 协议&#xff1a;Deepbook的更改将被还原。 #16673 开发者可能会看到更多编译器诊断&#xff0c;因为选择的解析错误不再阻止编译&#xff0c;并且编译器的诊断会到达后续编译阶段&#xff0c;其中可能会生成额外的诊断。 #16966…

基于Anaconda搭建Pytorch环境

准备虚拟环境 创建一个虚拟创建&#xff1a; conda create --name nlp python3.11.7激活虚拟环境&#xff1a; conda activate nlp安装pytorh 首先&#xff0c;可以通过任务管理器查看你的电脑是否支持GPU&#xff1a; 如果支持&#xff0c;到网址&#xff1a;https://py…

【c++每天一题】 快速幂

快速冥 描述 输入 b,p,k 的值&#xff0c;求 bp mod k的值。其中 b,p,k 为长整型数。 输入描述 输入 b,p,k 的值。 输出描述 求 bp mod k 的值。 样例输入 1 2 10 9 样例输出 1 7 代码&#xff1a; #include<bits/stdc.h> using namespace std; //求a的b次方%k的结果 …

深入了解排序算法:数据结构中的排序技术

目录 前言 1. 排序的基本概念 2. 插入排序 2.1插入排序的步骤&#xff1a; 2.2插入排序的示例&#xff08;升序排序&#xff09;&#xff1a; 2..3特点和复杂性分析&#xff1a; 2.4实现 2.5复杂性分析 2.6实际应用 3. 交换排序 3.1交换排序&#xff08;冒泡排序&…

力扣1518. 换水问题

题目链接 力扣1518. 换水问题 简单方法(模拟) 思路 对换水进行模拟&#xff0c;每次喝完 n u m E x c h a n g e numExchange numExchange 瓶水后就去换一瓶水&#xff0c;直到不能再兑换为止&#xff0c;也就是剩余水的数量小于 n u m E x c h a n g e numExchange numE…

智能穿戴终端设备安卓主板方案_MTK平台智能手表PCBA定制开发

新移科技智能手表方案兼容WiFi、BLE、2~5G等多种通信能力。支持多个功能模块&#xff0c;包括&#xff1a;通话、计步、定位、睡眠监测、心率监测、血氧监测等。智能手表通过滑动与功能性按键提供高度直观的体验感受&#xff0c;从腕间即可掌控日常生活。形态支持定制包括&…

【深度学习(1)】研0和研1如何上手深度学习及定方向

深度学习&#xff08;1&#xff09; 基础部分书籍鱼书 (理论部分) 视频课程我是土堆&#xff08;代码部分&#xff09; 提升部分李沐的动手学深度学习李沐老师的书 定方向网站&#xff1a; paperwithcode谷歌学术找论文 基础部分 书籍 鱼书 (理论部分) 适合入门&#xff0c;…

Java中的ArrayList

ArrayList<E>的特点 可调整大小的数组实现 <E>:是一种数据类型 ArrayList的构造方法 ArrayList list new ArrayList();创建一个空的集合对象 package dayhou40.day45; ​ import java.util.ArrayList; ​ public class Arraylisttest {public static void ma…

spring常用注解(五)lombok库

一、介绍&#xff1a; 1、简介&#xff1a; Lombok是一个作用于编辑器和构建工具的 Java 库&#xff0c;可以对编写的 Java 代码进行增强&#xff0c;比如说不用再写实体类的 getter 方法&#xff0c;equals 方法而是自动生成&#xff0c;自动生成日志输出变量等等&#xff0…

头歌平台——大数据技术——上机

有问题自行解决&#xff0c;本文档仅用于记录本人课程学习的过程 大数据上机 请先阅读注意事项 md文档用户可以按住ctrl 鼠标左键跳转至注意事项 ,Pdf用户请直接点击 文章目录 大数据上机 请先阅读注意事项[toc]**注意事项 此项必看**大数据技术概述大数据应用 Linux 系统的安…

Agent思维过程样例

关键概念: 九月销售额 概念拆解: - 数据来源: 需要从提供的数据文件中查询 - 数据格式和具体文件: 未知&#xff0c;需首先确认 反思: - 还未开始查询&#xff0c;因此尚无取得的要素/概念。 - 需要首先确定数据存放的具体文件及其格式。 思考: A. 首先需要确定存放数据的…

Java设计模式 _创建型模式_单例模式(懒汉式,饿汉式)

一、单例模式 1、单例模式&#xff08;Singleton Pattern&#xff09;是一种创建对象的设计模式。一个类负责创建自己的对象&#xff0c;同时确保只有1个对象被创建&#xff0c;这个类提供了一种访问其唯一的对象的方式&#xff0c;不需要在实例化该类的对象。从而保证了这个类…

【JVM】从i++到JVM栈帧

【JVM】从i到JVM栈帧 本篇博客将用两个代码例子&#xff0c;简单认识一下JVM与栈帧结构以及其作用 从i与i说起 先不急着看i和i&#xff0c;我们来看看JVM虚拟机&#xff08;请看VCR.JPG&#xff09; 我们初学JAVA的时候一定都听到过JAVA“跨平台”的特性&#xff0c;也就是…

c++图论基础(2)

目录 图的存储方式&#xff1a; 邻接矩阵&#xff1a; 代码实现&#xff1a; 邻接表&#xff1a; 代码实现&#xff1a; 邻接矩阵邻接表对比&#xff1a; 带权图&#xff1a; 邻接矩阵存储&#xff1a; 邻接表存储(代码实现)&#xff1a; 图的存储方式&#xff1a; 邻…

如果把软路由的网段更换成169.254.0.0/16会咋样?

前言 这几天有小伙伴在折腾软路由系统&#xff0c;然后问题就来了。 他咨询的是&#xff1a;为啥电脑连接软路由之后&#xff0c;无法访问软路由的管理页&#xff1f; 嗯。。。确实不是什么大事。但不注意看&#xff0c;还以为软路由没有正常获取到ip。 熟悉网络的小伙伴们都…

C语言读数据+遍历行数程序|Visual studio 2022

读数据遍历行数程序 记录一个度数遍历行数的程序 FILE* file2; int row2 0; file2 fopen("D://sins_mat2.txt", "r"); // file1 fopen("D://ga_mat2.txt", "r"); if (file2 NULL) {printf("open file1 failed.\n");re…