LLM推理优化笔记1:KV cache、Grouped-query attention等

news/2024/9/11 3:47:27/ 标签: 论文阅读, 笔记, LLM推理

KV cache

对于decoder-only 模型比如现在如火如荼的大模型,其在生成内容的过程中,为了避免冗余计算,会将Transformer里的self-attention的K和V矩阵给缓存起来,这个过程即为KV cache。

在这里插入图片描述

decoder-only模型的生成过程是自回归的(auto-regressive),生成过程中先根据输入生成下一个token,再将生成的token与输入一起生成下一个token,重复这个过程直到遇到停止符号或者达到限定的输出token个数。(gif图来自illustrated-gpt2)
在这里插入图片描述

因为decoder-only模型的生成过程是自回归的,并且decoder的self-attention是causal的,即每一个token的attention计算只与其前面的tokens有关,所以我们每生成一个token时都重复计算了前面出现过的token的attention。为了节省计算量,可以将已经计算过的token的attention矩阵存储下来,每生成下一个token时直接使用存储好的attention矩阵并将新计算的token attention存储起来。(下面图片来自博客,不考虑softmax和scale示意对比KV cache使用)

在这里插入图片描述

在每一步计算时,只需要使用到上一步计算过的K和V矩阵,所以KV cache只会缓存K和V。当然缓存的代价就是需要额外的显存存储:

  • 每缓存一个token,其需要的空间为 2 * precision_in_bytes * head_dim * n_heads * n_layers(式中2是因为缓存K和V两个矩阵,precision_in_bytes是token的存储精度占用字节大小,head_dim是attention的head维度,n_head是attention的head个数,n_layers是transformer的层个数)。
  • 对于16-bit精度的模型以最大上下文长度max_context_length进行批量推理要求的缓存大小2 * 2 * head_dim * n_heads * n_layers * max_context_length * batch_size,比如Llama-2-13B模型对应最大上下文窗口为4096,batch大小为8时要求的缓存显存最多高达25GB左右。

transformers包生成时默认使用KV cache(use_cache=True),我们可以用如下代码去测试一下使用了KV cache以及不使用时的性能差异。

## 代码来自 https://medium.com/@joaolages/kv-caching-explained-276520203249
import numpy as np
import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizerdevice = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)for use_cache in (True, False):times = []for _ in range(10):  # measuring 10 generationsstart = time.time()model.generate(**tokenizer("What is KV caching?", return_tensors="pt").to(device), use_cache=use_cache, max_new_tokens=1000)times.append(time.time() - start)print(f"{'with' if use_cache else 'without'} KV caching: {round(np.mean(times), 3)} +- {round(np.std(times), 3)} seconds")

Multi-query attention 和Grouped-query attention

Multi-query attention

Multi-query attention(MQA)出自2019年11月的论文《Fast Transformer Decoding: One Write-Head is All You Need》,它让multi-head attention里的多个head共享K和V矩阵,并做试验验了修改之后模型的性能下降不明显,但因为减少了参数,推理时KV cache占用的存储和读取时间都会少很多。

Grouped-query attention

在这里插入图片描述

Grouped-query attention(GQA)出自2023年5月的论文《GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints》, 如上图所示,它的共享K和V矩阵介于Multi-query attention(MQA)和Multi-head attention(MHA)之间,通过实验证明GQA可达到类似MQA的速度以及MHA的性能。

Grouped-query attention将query heads划分为G个groups,每一组query heads共享一个key head和value head,将 G Q A − G GQA_{-G} GQAG 记为有G个groups的grouped-query attention,则 G Q A − 1 GQA_{-1} GQA1为Multi-query attention, G Q A − H GQA_{-H} GQAH则等价于Multi-head attention。

论文还提出了一个将Multi-head attention模型转变MQA或GQA模型的方法,其分为两步:

  • 将MHA模型的checkpoint转变成MQA或GQA模型,使用如下图示意的mean pooling将多个K和V矩阵变成单个矩阵(论文做了试验比较选取第一个head、随机初始化、mean pooling,mean pooling的效果是最好的)。
  • 使用少量比例(5%左右)的预训练数据来继续预训练使模型适应新结构。

在这里插入图片描述

关于GQA的组个数选取,论文做了消融实验后对于总head个数为64时G选取的是8,而在Llama2-70B模型也是8(总heads数也为64)。

在这里插入图片描述

实现

不考虑性能的代码示意如下:

from dataclasses import dataclass
import math
import torch
import torch.nn as nn 
from torch.nn import functional as F@dataclass
class GPTConfig:block_size: int = 1024 # max sequence lengthvocab_size: int = 50257 # number of tokens: 50,000 BPE merges + 256 bytes tokens + 1 <|endoftext|> tokenn_layer: int = 12 # number of layersn_head: int = 12 # number of headsn_embd: int = 768 # embedding dimensionn_kv_heads: int = 12 # grouped-query的group个数def repeat_kv(hidden: torch.Tensor, n_rep: int) -> torch.Tensor:"""Perform repeat of kv heads along a particular dimension.hidden.shape expected to be: (batch size, seq len, kv_n_heads, head_dim)n_rep: amount of repetitions of kv_n_headsUnlike torch.repeat_interleave, this function avoids allocating new memory.from https://huggingface.co/mosaicml/mpt-7b-chat/blob/main/attention.py#L47llama2里的写法差不多https://github.com/meta-llama/llama/blob/llama_v2/llama/model.py#L164C1-L165C1"""if n_rep == 1:return hidden(b, s, kv_n_heads, d) = hidden.shapehidden = hidden[:, :, :, None, :].expand(b, s, kv_n_heads, n_rep, d)return hidden.reshape(b, s, kv_n_heads * n_rep, d)## adapt from https://github.com/karpathy/nanoGPT/blob/master/model.py
class MultiHeadAttention(nn.Module):def __init__(self, config):super().__init__()assert config.n_embd % config.n_head == 0# key, query, value projections for all heads, but in a batchself.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)# output projectionself.c_proj = nn.Linear(config.n_embd, config.n_embd)# regularizationself.n_head = config.n_headself.n_embd = config.n_embd# not really a 'bias', more of a mask, but following the OpenAI/HF naming thoughself.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size)).view(1, 1, config.block_size, config.block_size))def forward(self, x):B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)# calculate query, key, values for all heads in batch and move head forward to be the batch dim# nh is "number of heads", hs is "head size", and C (number of channels) = nh * hs# e.g. in GPT-2 (124M), n_head=12, hs=64, so nh*hs=C=768 channels in the Transformerqkv = self.c_attn(x)q, k, v = qkv.split(self.n_embd, dim=2)k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)# attention (materializes the large (T,T) matrix for all the queries and keys)att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))att = F.softmax(att, dim=-1)y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side# output projectiony = self.c_proj(y)return y### multi-query
class MultiQueryAttention(nn.Module):def __init__(self, config):super().__init__()assert config.n_embd % config.n_head == 0# key, query, value projections for all heads, but in a batchself.c_attn = nn.Linear(config.n_embd, config.n_embd + 2*config.n_embd//config.n_head)# output projectionself.c_proj = nn.Linear(config.n_embd, config.n_embd)# regularizationself.n_head = config.n_headself.n_embd = config.n_embd# not really a 'bias', more of a mask, but following the OpenAI/HF naming thoughself.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size)).view(1, 1, config.block_size, config.block_size))def forward(self, x):B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)# calculate query, key, values for all heads in batch and move head forward to be the batch dim# nh is "number of heads", hs is "head size", and C (number of channels) = nh * hs# e.g. in GPT-2 (124M), n_head=12, hs=64, so nh*hs=C=768 channels in the Transformerqkv = self.c_attn(x)q, k, v = qkv.split([self.n_embd, self.n_embd//self.n_head, self.n_embd//self.n_head], dim=2)q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)k = repeat_kv(k.view(B, T, 1, C // self.n_head), self.n_head).transpose(1, 2) # (B, nh, T, hs)v = repeat_kv(v.view(B, T, 1, C // self.n_head), self.n_head).transpose(1, 2) # (B, nh, T, hs)# attention (materializes the large (T,T) matrix for all the queries and keys)att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))att = F.softmax(att, dim=-1)y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side# output projectiony = self.c_proj(y)return y### grouped-query attention
class GroupedQueryAttention(nn.Module):def __init__(self, config):super().__init__()assert config.n_embd % config.n_head == 0# key, query, value projections for all heads, but in a batchself.c_attn = nn.Linear(config.n_embd, config.n_embd + 2*config.n_kv_heads*config.n_embd//config.n_head)# output projectionself.c_proj = nn.Linear(config.n_embd, config.n_embd)# regularizationself.n_head = config.n_headself.n_embd = config.n_embdself.n_kv_heads = config.n_kv_heads# not really a 'bias', more of a mask, but following the OpenAI/HF naming thoughself.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size)).view(1, 1, config.block_size, config.block_size))def forward(self, x):B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)# calculate query, key, values for all heads in batch and move head forward to be the batch dim# nh is "number of heads", hs is "head size", and C (number of channels) = nh * hs# e.g. in GPT-2 (124M), n_head=12, hs=64, so nh*hs=C=768 channels in the Transformerqkv = self.c_attn(x)q, k, v = qkv.split([self.n_embd, self.n_kv_heads*self.n_embd//self.n_head, self.n_kv_heads*self.n_embd//self.n_head], dim=2)q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)k = repeat_kv(k.view(B, T, self.n_kv_heads, C // self.n_head), self.n_head//self.n_kv_heads).transpose(1, 2) # (B, nh, T, hs)v = repeat_kv(v.view(B, T, self.n_kv_heads, C // self.n_head), self.n_head//self.n_kv_heads).transpose(1, 2) # (B, nh, T, hs)# attention (materializes the large (T,T) matrix for all the queries and keys)att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))att = F.softmax(att, dim=-1)y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side# output projectiony = self.c_proj(y)return y

Sliding Window Attention

Mistral 7B使用Sliding Window Attention(SWA)来减少KV cache的内存占用,每次计算attention时,只考虑固定窗口大小W内的信息。对于位置i的隐状态,只会考虑在其前面i-W到i的窗口内的隐状态信息,如下图示意所示,所以对于在第k层的位置i来说,最多可以访问到 W × k W\times k W×k个tokens。在Mistral 7B里,W=4096,层数为32,所以理论上的attention范围近似为131K。

在这里插入图片描述

因为使用固定attention窗口,所以Mistral 7B使用滚动(rolling) buffer cache, cache大小固定为W,在时刻t的K和V存储在cache的第i mod W个位置,也就是说如果位置i比W大,cache中原先存储的值会被覆盖掉。下图是W=3时的示意。
在这里插入图片描述

参考资料

  1. 看图学KV Cache

  2. Transformer Inference Arithmetic

  3. Transformers KV Caching Explained(其gif动画有助于加深理解)

  4. KV caching内存增长

  5. KV cache 是chatbot 规模化的一大工程挑战

  6. Techniques for KV Cache Optimization in Large Language Models

  7. KV cache quantization

  8. Inference Optimization


http://www.ppmy.cn/news/1475436.html

相关文章

c++课后作业

把字符串转换为整数 int main() {char pn[21];cout << "请输入一个由数字组成的字符串&#xff1a; ";cin >> pn;int last 0;int res[10];int j strlen(pn);int idx 2;cout << "请选择&#xff08;2-二进制&#xff0c;10-十进制&#xf…

项目/代码规范与Apifox介绍使用

目录 目录 一、项目规范&#xff1a; &#xff08;一&#xff09;项目结构&#xff1a; &#xff08;二&#xff09;传送的数据对象体 二、代码规范&#xff1a; &#xff08;一&#xff09;数据库命名规范&#xff1a; &#xff08;二&#xff09;注释规范&#xff1a; …

JavaSE 面向对象程序设计进阶 IO流练习 字节缓冲流 字符缓冲流 底层原理

目录 字节缓冲流 字节缓冲流底层原理 字符缓冲流 字节缓冲流 刚刚学习的四个流是基本流 对四个基本流进行封装&#xff0c;添加了新的功能&#xff0c;叫做缓冲流 底层自带长度为8192的缓冲区 import java.io.*;public class Main {public static void main(String[] args) …

xcode中对项目或者文件文件夹重命名操作

提起揭秘答案&#xff1a;选中文件后&#xff0c;按下回车键就可以了 如果在项目中对新建的文件夹或者文件名称不满意或者输入错误了&#xff0c;想要修改一下名称该怎么办&#xff1f;如果是在文件或文件夹上右键是没有rename选项的&#xff1a; 其实想要重命名&#xff0c;很…

红日靶场----(三)1.漏洞利用

上期已经信息收集阶段已经完成&#xff0c;接下来是漏洞利用。 靶场思路 通过信息收集得到两个吧靶场的思路 1、http://192.168.195.33/phpmyadmin/&#xff08;数据库的管理界面&#xff09; root/root 2、http://192.168.195.33/yxcms/index.php?radmin/index/login&am…

自定义波形图View,LayoutInflater动态加载控件保存为本地图片

效果图&#xff1a; 页面布局&#xff1a; <?xml version"1.0" encoding"utf-8"?><LinearLayout xmlns:android"http://schemas.android.com/apk/res/android"xmlns:tools"http://schemas.android.com/tools"android:la…

AirPods Pro新功能前瞻:iOS 18的五大创新亮点

随着科技的不断进步&#xff0c;苹果公司一直在探索如何通过创新提升用户体验。iOS 18的推出&#xff0c;不仅仅是iPhone的一次系统更新&#xff0c;更是苹果生态链中重要一环——AirPods Pro的一次重大升级。 据悉&#xff0c;iOS 18将为AirPods Pro带来五项新功能&#xff0…

电子签章 签到 互动 打卡 创意印章 支持小程序 H5 App

电子签章 签到 互动 打卡 创意印章 支持小程序 H5 App 定制化

iOS 开发中不常见的专业术语

乐此不疲地把简单的问题复杂化&#xff0c;并把这种XX行为叫作专业 APM 在 iOS 开发中&#xff0c;APM 代表 Application Performance Management&#xff08;应用性能管理&#xff09;。APM 是一套监控和管理应用程序性能的工具和技术&#xff0c;旨在确保应用程序运行平稳、…

Django 新增数据 save()方法

1&#xff0c;添加模型 Test/app11/models.py from django.db import modelsclass Book(models.Model):title models.CharField(max_length100)author models.CharField(max_length100)publication_date models.DateField()price models.DecimalField(max_digits5, decim…

常见的网络安全设备

一、防火墙 防火墙的核心任务&#xff1a;防护和控制&#xff0c;防火墙通过安全策略识别流量并做出相应的动作。 防火墙的安全策略在进行匹配时&#xff0c;自上而下逐一匹配&#xff0c;匹配成功则不向下进行匹配&#xff0c;末尾隐含拒绝所有规则。 1.包过滤防火墙 工作范围…

【C++深度探索】全面解析多态性机制(二)

&#x1f525; 个人主页&#xff1a;大耳朵土土垚 &#x1f525; 所属专栏&#xff1a;C从入门至进阶 这里将会不定期更新有关C/C的内容&#xff0c;欢迎大家点赞&#xff0c;收藏&#xff0c;评论&#x1f973;&#x1f973;&#x1f389;&#x1f389;&#x1f389; 前言 我…

Elasticsearch 理解相关性评分(TF-IDF、BM25等)

在Elasticsearch中&#xff0c;相关性评分是搜索功能的核心&#xff0c;它决定了搜索结果的质量和排序。了解Elasticsearch是如何计算相关性评分的&#xff0c;特别是TF-IDF和BM25算法&#xff0c;对于优化搜索性能和结果至关重要。本文将深入探讨这两种算法及其在Elasticsearc…

vue3 ts 报错:无法找到模块“../views/index/Home.vue”的声明文件

解决办法&#xff1a; env.d.ts 新增代码片段&#xff1a; declare module "*.vue" {import type { DefineComponent } from "vue";// eslint-disable-next-line typescript-eslint/no-explicit-any, typescript-eslint/ban-typesconst component: Define…

C#面:阐述控制反转是什么?

控制反转&#xff08;Inversion of Control&#xff0c;缩写为IoC&#xff09;&#xff0c;是⾯向对象编程中的⼀种设计原则&#xff0c;可以⽤来减低计算机代码之间的耦合度。其中最常⻅的⽅式叫做依赖注⼊&#xff08;Dependency Injection&#xff0c;简称DI&#xff09;&am…

深入解析C#中的Stopwatch类:精准计时的艺术

目录 引言 了解Stopwatch类 创建与使用Stopwatch 使用多个Stopwatch实例 性能分析与优化 结论 后记 引言 在软件开发中&#xff0c;性能分析是不可或缺的一环&#xff0c;它帮助我们识别瓶颈、优化代码&#xff0c;确保应用程序的高效运行。C#中的Stopwatch类便是开发者…

数据湖仓一体(六)安装flink

上传安装包到/opt/software目录并解压 [bigdatanode106 software]$ tar -zxvf flink-1.17.2-bin-scala_2.12.tgz -C /opt/services/ 重命名文件 [bigdatanode106 services]$ mv flink-1.17.2-bin-scala_2.12 flink-1.17.2 配置环境变量 [bigdatanode106 ~]$ sudo vim /etc…

Elasticsearch:Node.js ECS 日志记录 - Morgan

这是之前系列文章&#xff1a; Elasticsearch&#xff1a;Node.js ECS 日志记录 - Pino Elasticsearch&#xff1a;Node.js ECS 日志记录 - Winston 中的第三篇文章。在今天的文章中&#xff0c;我将描述如何使用 Morgan 包针对 Node.js 应用进行日子记录。此 Morgan Node.j…

threejs

1.场景清空&#xff0c;释放内容 // 假设你已经有一个Three.js的场景对象scene// 函数&#xff1a;清空场景中的所有对象 function clearScene(scene) {while(scene.children.length > 0){const object scene.children[0];if(object.isMesh) {// 如果有几何体和材质&#…

2024年上半年信息系统项目管理师——综合知识真题题目及答案(第1批次)(2)

2024年上半年信息系统项目管理师 ——综合知识真题题目及答案&#xff08;第1批次&#xff09;&#xff08;2&#xff09; 第21题&#xff1a;在一个大型信息系统项目中&#xff0c;项目经理发现尽管已经建立了沟通机制&#xff0c;但团队间的沟通依然不畅&#xff0c;项目风险…