探索和构建 LLaMA 3 架构:深入探讨组件、编码和推理技术(四)分组多查询注意力

news/2024/9/25 8:32:02/

探索和构建 LLaMA 3 架构:深入探讨组件、编码和推理技术(四)分组多查询注意力

Grouped-query Attention,简称GQA

分组查询注意力(Grouped-query Attention,简称GQA)是多查询和多头注意力的插值。它在保持与多查询注意力相当的处理速度的同时,实现了与多头注意力相似的质量。

在这里插入图片描述

自回归解码的标准做法是缓存序列中先前标记的键和值,以加快注意力计算的速度。

  • 然而,随着上下文窗口或批量大小的增加,多头注意力(Multi-Head Attention,简称MHA)模型中键值缓存(Key-Value Cache,简称KV Cache)的大小所关联的内存成本显著增加。

  • 多查询注意力(Multi-Query Attention,简称MQA)是一种机制,它对多个查询仅使用单个键值头,这可以节省内存并大幅加快解码器的推理速度。

  • Llama(一种模型)整合了GQA,以解决在Transformer模型自回归解码期间的内存带宽挑战。主要问题源于GPU进行计算的速度比它们将数据移入内存的速度快。在每个阶段都需要加载解码器权重和注意力键,这消耗了大量的内存。

在这里插入图片描述
在这里插入图片描述

class SelfAttention(nn.Module): def  __init__(self, args: ModelArgs):super().__init__()self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads# Indicates the number of heads for the queriesself.n_heads_q = args.n_heads# Indiates how many times the heads of keys and value should be repeated to match the head of the Queryself.n_rep = self.n_heads_q // self.n_kv_heads# Indicates the dimentiona of each headself.head_dim = args.dim // args.n_headsself.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)self.cache_k = torch.zeros((args.max_batch_size, args.max_seq_len, self.n_kv_heads, self.head_dim))self.cache_v = torch.zeros((args.max_batch_size, args.max_seq_len, self.n_kv_heads, self.head_dim))def forward(self, x: torch.Tensor, start_pos: int, freq_complex: torch.Tensor):batch_size, seq_len, _ = x.shape #(B, 1, dim)# Apply the wq, wk, wv matrices to query, key and value# (B, 1, dim) -> (B, 1, H_q * head_dim)xq = self.wq(x)# (B, 1, dim) -> (B, 1, H_kv * head_dim)xk = self.wk(x)xv = self.wv(x)# (B, 1, H_q * head_dim) -> (B, 1, H_q, head_dim)xq = xq.view(batch_size, seq_len, self.n_heads_q, self.head_dim)xk = xk.view(batch_size, seq_len, self.n_kv_heads, self.head_dim)# (B, 1, H_kv * head_dim) -> (B, 1, H_kv, head_dim)xv = xv.view(batch_size, seq_len, self.n_kv_heads, self.head_dim)# Apply the rotary embeddings to the keys and values# Does not chnage the shape of the tensor# (B, 1, H_kv, head_dim) -> (B, 1, H_kv, head_dim)xq = apply_rotary_embeddings(xq, freq_complex, device=x.device)xk = apply_rotary_embeddings(xk, freq_complex, device=x.device)# Replace the enty in the cache for this tokenself.cache_k[:batch_size, start_pos:start_pos + seq_len] = xkself.cache_v[:batch_size, start_pos:start_pos + seq_len] = xv# Retrive all the cached keys and values so far# (B, seq_len_kv, H_kv, head_dim)keys = self.cache_k[:batch_size, 0:start_pos + seq_len]values = self.cache_v[:batch_size, 0:start_pos+seq_len] # Repeat the heads of the K and V to reach the number of heads of the querieskeys = repeat_kv(keys, self.n_rep)values = repeat_kv(values, self.n_rep)# (B, 1, h_q, head_dim) --> (b, h_q, 1, head_dim)xq = xq.transpose(1, 2)keys = keys.transpose(1, 2)values = values.transpose(1, 2)# (B, h_q, 1, head_dim) @ (B, h_kv, seq_len-kv, head_dim) -> (B, h_q, 1, seq_len-kv)scores = torch.matmul(xq, keys.transpose(2,3)) / math.sqrt(self.head_dim)scores = F.softmax(scores.float(), dim=-1).type_as(xq)# (B, h_q, 1, seq_len) @ (B, h_q, seq_len-kv, head_dim) --> (b, h-q, q, head_dim)output = torch.matmul(scores, values)# (B, h_q, 1, head_dim) -> (B, 1, h_q, head_dim) -> ()output = (output.transpose(1,2).contiguous().view(batch_size, seq_len, -1))return self.wo(output) # (B, 1, dim) -> (B, 1, dim)

系列博客

探索和构建 LLaMA 3 架构:深入探讨组件、编码和推理技术(一)
https://duanzhihua.blog.csdn.net/article/details/138208650
探索和构建 LLaMA 3 架构:深入探讨组件、编码和推理技术(二)
https://duanzhihua.blog.csdn.net/article/details/138212328

探索和构建 LLaMA 3 架构:深入探讨组件、编码和推理技术(三)KV缓存
https://duanzhihua.blog.csdn.net/article/details/138213306
在这里插入图片描述
在这里插入图片描述


http://www.ppmy.cn/news/1438931.html

相关文章

让流程图动起来

我们平时画流程,然后贴到文档,就完事了。但是过程演示的时候,如果只是一张静态图,很难吸引到听众的注意力,表达效果并不太好。常用的方法是可以用PPT进行动态演示,做PPT也是需要花一些时间,同时…

循环神经网络实例——序列预测

我们生活的世界充满了形形色色的序列数据,只要是有顺序的数据统统都可以看作是序列数据,比如文字是字符的序列,音乐是音符组成的序列,股价数据也是序列,连DNA序列也属于序列数据。循环神经网络RNN天生就具有处理序列数…

Linux的学习之路:20、进程信号(2)

摘要 本章讲一下进程信号的阻塞信号和捕捉信号和可重入函数 目录 摘要 一、阻塞信号 1、阻塞信号 2、信号集操作函数 二、捕捉信号 1、内核如何实现信号的捕捉 2、代码实演 三、可重入函数 一、阻塞信号 1、阻塞信号 实际执行信号的处理动作称为信号递达(Delivery) …

【Java】从0实现一个消息队列中间件

从0实现一个消息队列中间件 什么是消息队列需求分析核心概念核心API交换机类型持久化网络通信网络通信API 消息应答 模块划分项目创建创建核心类创建Exchange创建MSGQueue创建Binding创建Message 数据库设计配置sqlite实现创建表和数据库基本操作 实现DataBaseManager创建DataB…

Linux安装Kubernetes(k8s)详细教程

系统初始化 生产环境肯定要更高配置,虚拟机以保守的最低配置。 机器ip规格master192.168.203.111核2线程、2G内存、40G磁盘node2192.168.203.121核2线程、2G内存、40G磁盘node3192.168.203.131核2线程、2G内存、40G磁盘 修改为静态ip vi /etc/resolv.conf追加内容…

如何在Pycharm中使用Git来进行版本管理

推荐视频:git pycharm的使用 连接github_哔哩哔哩_bilibilipycharm git的使用简单介绍 最近应该不会更新技能相关视频了 准备开题, 视频播放量 13042、弹幕量 2、点赞数 208、投硬币枚数 143、收藏人数 343、转发人数 58, 视频作者 呃呃燕, 作者简介 努力入门的计算机双非研究生…

原型模式(上机考试抽题)

定义 原型模式主要解决的问题就是创建复对象,⽽这部分 对象 内容本身⽐较复杂,⽣成过程可能从库或者RPC接⼝中获取数据的耗时较⻓,因此采⽤克隆的⽅式节省时间。 上机考试抽题 从⼀部分可以上机考试的内容开始,在保证⼤家的公平…

《第二行代码》第二版学习笔记(6)——内容提供器

文章目录 一 运行时权限2.权限分类3 运行时申请权限 二、内容提供器1、 ContentResolver的基本用法2、现有的内容提供器3、创建自己的内容提供器2.1 创建内容提供器的步骤2.2 跨程序数据共享 内容提供器(Content Provider)主要用于在不同的应用程序之间实…