LLM Inference Optimization Notes 1: KV Cache, Grouped-Query Attention, and More

2024/9/11 3:47:27, Tags: paper reading, notes, LLM inference

KV cache

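For decoder-only models, such as today's hugely popular large language models, the K and V matrices of the Transformer's self-attention are cached during generation to avoid redundant computation. This technique is called KV caching.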



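Generation in a decoder-only model is autoregressive, and the decoder's self-attention is causal: the attention output for each token depends only on the tokens before it. Without a cache, every generation step therefore recomputes the keys and values of all earlier tokens. To save computation, the K and V already computed for earlier tokens can be stored; each step then reuses the stored entries and appends the K and V of the newly processed token. (The figure below, taken from a blog post, compares generation with and without the KV cache, omitting softmax and scaling for simplicity.)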
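The equivalence described above can be checked on a toy example. The following is a minimal single-head sketch in NumPy, not the implementation of any particular library; the function names, the random weights, and the sizes T and d are all assumptions made for illustration. It compares full causal attention over the whole sequence with token-by-token decoding that reuses cached K and V.

```python
import numpy as np


def softmax(x):
    # Numerically stable softmax over the last axis.
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)


def causal_attention(x, Wq, Wk, Wv):
    """Full recomputation: causal attention over the whole sequence x of shape (T, d)."""
    q, k, v = x @ Wq, x @ Wk, x @ Wv
    scores = q @ k.T / np.sqrt(k.shape[-1])
    # Mask out future positions (strictly upper triangle).
    future = np.triu(np.ones_like(scores, dtype=bool), k=1)
    scores = np.where(future, -np.inf, scores)
    return softmax(scores) @ v


def decode_with_cache(x, Wq, Wk, Wv):
    """Incremental decoding: one token per step, reusing cached K and V."""
    k_cache, v_cache, outputs = [], [], []
    for t in range(x.shape[0]):
        x_t = x[t:t + 1]                # (1, d), only the newest token
        q_t = x_t @ Wq                  # the query is needed for this step only
        k_cache.append(x_t @ Wk)        # append the new key to the cache
        v_cache.append(x_t @ Wv)        # append the new value to the cache
        K = np.concatenate(k_cache)     # (t + 1, d)
        V = np.concatenate(v_cache)     # (t + 1, d)
        scores = q_t @ K.T / np.sqrt(K.shape[-1])
        outputs.append(softmax(scores) @ V)
    return np.concatenate(outputs)


rng = np.random.default_rng(0)
T, d = 6, 8  # toy sequence length and model dimension
x = rng.normal(size=(T, d))
Wq, Wk, Wv = (rng.normal(size=(d, d)) for _ in range(3))

full = causal_attention(x, Wq, Wk, Wv)
cached = decode_with_cache(x, Wq, Wk, Wv)
print(np.allclose(full, cached))
```

In the cached version each step computes projections for a single token only, while the full version recomputes them for the entire prefix; that is the redundancy the cache removes.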


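Each step needs only the K and V computed in earlier steps (the queries of earlier tokens are never reused), so the KV cache stores only K and V. The price of caching is extra GPU memory: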

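  • Caching one token takes 2 * precision_in_bytes * head_dim * n_heads * n_layers bytes (the factor 2 accounts for the two matrices K and V; precision_in_bytes is the number of bytes per stored value; head_dim is the dimension of each attention head; n_heads is the number of attention heads; n_layers is the number of Transformer layers).
  • For a 16-bit model running batched inference at the maximum context length max_context_length, the cache needs 2 * 2 * head_dim * n_heads * n_layers * max_context_length * batch_size bytes. For example, Llama-2-13B with its 4096-token context window and a batch size of 8 can need up to about 25 GB of cache memory.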
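The 25 GB figure can be reproduced from the formula above. The sketch below is a plain arithmetic check; the helper name kv_cache_bytes is hypothetical, and the Llama-2-13B hyperparameters (head_dim = 128, n_heads = 40, n_layers = 40) are not stated in this post and are assumed here from the published model configuration.

```python
def kv_cache_bytes(head_dim, n_heads, n_layers,
                   seq_len=1, batch_size=1, precision_in_bytes=2):
    """KV cache size in bytes; the leading 2 counts the K and V matrices."""
    return (2 * precision_in_bytes * head_dim * n_heads * n_layers
            * seq_len * batch_size)


# Assumed Llama-2-13B shape: head_dim=128, n_heads=40, n_layers=40.
per_token = kv_cache_bytes(head_dim=128, n_heads=40, n_layers=40)
total = kv_cache_bytes(head_dim=128, n_heads=40, n_layers=40,
                       seq_len=4096, batch_size=8)

print(per_token)      # 819200 bytes, i.e. 800 KiB per token
print(total / 2**30)  # 25.0 (GiB)
```

Under these assumptions every cached token costs 800 KiB, so the cache, not the weights, quickly becomes the dominant memory consumer at long contexts and large batch sizes.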

The transformers package uses the KV cache by default during generation (use_cache=True). The following code measures the performance difference with and without it.

## Code from https://medium.com/@joaolages/kv-caching-explained-276520203249
import numpy as np
import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)

for use_cache in (True, False):
    times = []
    for _ in range(10):  # measuring 10 generations
        start = time.time()
        model.generate(**tokenizer("What is KV caching?", return_tensors="pt").to(device), use_cache=use_cache, max_new_tokens=1000)
        times.append(time.time() - start)
    print(f"{'with' if use_cache else 'without'} KV caching: {round(np.mean(times), 3)} +- {round(np.std(times), 3)} seconds")

Multi-query attention and Grouped-query attention

Multi-query attention

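Multi-query attention (MQA) comes from the November 2019 paper "Fast Transformer Decoding: One Write-Head is All You Need". It lets all heads of multi-head attention share a single K and V matrix. The paper's experiments show that model quality drops only slightly after this change, while the KV cache takes much less memory and much less time to read during inference, because there are far fewer K and V parameters.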

Grouped-query attention


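Grouped-query attention (GQA) comes from the May 2023 paper "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints". As shown in the figure above, its degree of K and V sharing sits between Multi-query attention (MQA) and Multi-head attention (MHA). The paper's experiments show that GQA reaches speed close to MQA and quality close to MHA.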

Grouped-query attention divides the query heads into G groups, and the query heads in each group share one key head and one value head. Writing $\mathrm{GQA}_{-G}$ for grouped-query attention with G groups, $\mathrm{GQA}_{-1}$ is Multi-query attention, and $\mathrm{GQA}_{-H}$ (with H the number of query heads) is equivalent to Multi-head attention.
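The grouping can be made concrete by listing which key/value head each query head reads. This is a small illustrative sketch; the helper name kv_head_for_query_head is hypothetical, and it assumes that consecutive query heads form a group and that H is divisible by G.

```python
def kv_head_for_query_head(query_head, n_heads, n_groups):
    """Index of the shared key/value head used by a given query head.

    Assumes consecutive query heads form a group (an illustrative choice).
    """
    assert n_heads % n_groups == 0, "n_heads must be divisible by n_groups"
    heads_per_group = n_heads // n_groups
    return query_head // heads_per_group


H = 8  # number of query heads
for G in (1, 4, 8):
    mapping = [kv_head_for_query_head(h, H, G) for h in range(H)]
    print(f"GQA-{G}: {mapping}")
# GQA-1: [0, 0, 0, 0, 0, 0, 0, 0]   (multi-query attention)
# GQA-4: [0, 0, 1, 1, 2, 2, 3, 3]
# GQA-8: [0, 1, 2, 3, 4, 5, 6, 7]   (multi-head attention)
```

Only G key/value heads have to be cached per layer, so relative to MHA the KV cache shrinks by a factor of H / G.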

The paper also proposes a method for converting a Multi-head attention model into an MQA or GQA model, in two steps:

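  • Convert the MHA checkpoint into an MQA or GQA checkpoint by mean pooling the K and V projection matrices of multiple heads into a single matrix (one per group for GQA), as illustrated in the figure below. (The paper compares taking the first head, random initialization, and mean pooling; mean pooling works best.)
  • Continue pre-training on a small fraction (about 5%) of the pre-training data so that the model adapts to the new structure.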
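The first step above can be sketched as follows. This is only an illustration of mean pooling, not the paper's code: the function name mean_pool_kv_heads is hypothetical, and the weight layout (shape (d_model, n_heads * head_dim), with each head occupying a contiguous block of columns) is an assumption that differs between real checkpoints.

```python
import numpy as np


def mean_pool_kv_heads(w, n_heads, n_groups):
    """Mean-pool the per-head K (or V) projection weights into n_groups heads.

    w: array of shape (d_model, n_heads * head_dim); the heads are assumed to be
       laid out as contiguous column blocks (hypothetical layout).
    Returns an array of shape (d_model, n_groups * head_dim).
    """
    d_model, out_dim = w.shape
    assert out_dim % n_heads == 0 and n_heads % n_groups == 0
    head_dim = out_dim // n_heads
    heads_per_group = n_heads // n_groups
    # Split the columns into (group, head within group, head_dim) ...
    w = w.reshape(d_model, n_groups, heads_per_group, head_dim)
    # ... and average over the heads that fall into the same group.
    return w.mean(axis=2).reshape(d_model, n_groups * head_dim)


rng = np.random.default_rng(0)
d_model, n_heads, head_dim = 16, 4, 4
w_k = rng.normal(size=(d_model, n_heads * head_dim))

w_k_gqa = mean_pool_kv_heads(w_k, n_heads, n_groups=2)  # GQA with 2 groups
w_k_mqa = mean_pool_kv_heads(w_k, n_heads, n_groups=1)  # MQA, a single head
print(w_k_gqa.shape, w_k_mqa.shape)  # (16, 8) (16, 4)
```

The same pooling would be applied to the V projection (and to the biases, if any); the query and output projections stay unchanged, and the second step then lets the model adapt to the pooled weights.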






from dataclasses import dataclass
import math
import torch
import torch.nn as nn
from torch.nn import functional as F


@dataclass
class GPTConfig:
    block_size: int = 1024  # max sequence length
    vocab_size: int = 50257  # number of tokens: 50,000 BPE merges + 256 bytes tokens + 1 <|endoftext|> token
    n_layer: int = 12  # number of layers
    n_head: int = 12  # number of heads
    n_embd: int = 768  # embedding dimension
    n_kv_heads: int = 12  # number of groups (key/value heads) for grouped-query attention


def repeat_kv(hidden: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Perform repeat of kv heads along a particular dimension.

    hidden.shape expected to be: (batch size, seq len, kv_n_heads, head_dim)
    n_rep: amount of repetitions of kv_n_heads
    Unlike torch.repeat_interleave, this function avoids allocating new memory.

    from https://huggingface.co/mosaicml/mpt-7b-chat/blob/main/attention.py#L47
    The Llama 2 implementation is similar:
    https://github.com/meta-llama/llama/blob/llama_v2/llama/model.py#L164C1-L165C1
    """
    if n_rep == 1:
        return hidden
    (b, s, kv_n_heads, d) = hidden.shape
    hidden = hidden[:, :, :, None, :].expand(b, s, kv_n_heads, n_rep, d)
    return hidden.reshape(b, s, kv_n_heads * n_rep, d)


## adapted from https://github.com/karpathy/nanoGPT/blob/master/model.py
class MultiHeadAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        # not really a 'bias', more of a mask, but following the OpenAI/HF naming though
        self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                             .view(1, 1, config.block_size, config.block_size))

    def forward(self, x):
        B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)
        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        # nh is "number of heads", hs is "head size", and C (number of channels) = nh * hs
        # e.g. in GPT-2 (124M), n_head=12, hs=64, so nh*hs=C=768 channels in the Transformer
        qkv = self.c_attn(x)
        q, k, v = qkv.split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        # attention (materializes the large (T,T) matrix for all the queries and keys)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
        att = F.softmax(att, dim=-1)
        y = att @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side
        # output projection
        y = self.c_proj(y)
        return y


### multi-query attention
class MultiQueryAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # query projection for all heads, plus a single shared key head and value head
        self.c_attn = nn.Linear(config.n_embd, config.n_embd + 2 * config.n_embd // config.n_head)
        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        # not really a 'bias', more of a mask, but following the OpenAI/HF naming though
        self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                             .view(1, 1, config.block_size, config.block_size))

    def forward(self, x):
        B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)
        # nh is "number of heads", hs is "head size", and C (number of channels) = nh * hs
        qkv = self.c_attn(x)
        q, k, v = qkv.split([self.n_embd, self.n_embd // self.n_head, self.n_embd // self.n_head], dim=2)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        # the single key/value head is repeated so every query head sees the same K and V
        k = repeat_kv(k.view(B, T, 1, C // self.n_head), self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        v = repeat_kv(v.view(B, T, 1, C // self.n_head), self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        # attention (materializes the large (T,T) matrix for all the queries and keys)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
        att = F.softmax(att, dim=-1)
        y = att @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side
        # output projection
        y = self.c_proj(y)
        return y


### grouped-query attention
class GroupedQueryAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        assert config.n_head % config.n_kv_heads == 0  # each group must hold the same number of query heads
        # query projection for all heads, plus n_kv_heads key heads and value heads
        self.c_attn = nn.Linear(config.n_embd, config.n_embd + 2 * config.n_kv_heads * config.n_embd // config.n_head)
        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.n_kv_heads = config.n_kv_heads
        # not really a 'bias', more of a mask, but following the OpenAI/HF naming though
        self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                             .view(1, 1, config.block_size, config.block_size))

    def forward(self, x):
        B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)
        # nh is "number of heads", hs is "head size", and C (number of channels) = nh * hs
        qkv = self.c_attn(x)
        kv_dim = self.n_kv_heads * self.n_embd // self.n_head
        q, k, v = qkv.split([self.n_embd, kv_dim, kv_dim], dim=2)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        # each key/value head is repeated for the query heads in its group
        k = repeat_kv(k.view(B, T, self.n_kv_heads, C // self.n_head), self.n_head // self.n_kv_heads).transpose(1, 2)  # (B, nh, T, hs)
        v = repeat_kv(v.view(B, T, self.n_kv_heads, C // self.n_head), self.n_head // self.n_kv_heads).transpose(1, 2)  # (B, nh, T, hs)
        # attention (materializes the large (T,T) matrix for all the queries and keys)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
        att = F.softmax(att, dim=-1)
        y = att @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side
        # output projection
        y = self.c_proj(y)
        return y

Sliding Window Attention

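Mistral 7B uses Sliding Window Attention (SWA) to reduce the memory footprint of the KV cache: each attention computation only considers information inside a fixed window of size W. The hidden state at position i attends only to hidden states in the window from i-W to i, as shown in the figure below. Because each layer extends the reach by another W positions, position i at layer k can access up to $W \times k$ tokens. In Mistral 7B, W = 4096 and there are 32 layers, so the theoretical attention span is about 131K tokens.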
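The window restriction amounts to a banded causal mask. The sketch below builds such a mask in plain Python; the helper names are hypothetical, and the exact window boundary is an assumption: here each position attends to itself and the W-1 positions before it (W tokens in total), which is consistent with a cache of size W.

```python
def sliding_window_mask(seq_len, window):
    """mask[i][j] is True when query position i may attend to key position j.

    Assumption: the window covers position i itself plus the window - 1
    positions before it.
    """
    return [[(j <= i) and (i - j < window) for j in range(seq_len)]
            for i in range(seq_len)]


def theoretical_span(window, n_layers):
    """Upper bound on how far back information can propagate through the stack."""
    return window * n_layers


mask = sliding_window_mask(seq_len=6, window=3)
for row in mask:
    print("".join("x" if allowed else "." for allowed in row))
# x.....
# xx....
# xxx...
# .xxx..
# ..xxx.
# ...xxx

print(theoretical_span(window=4096, n_layers=32))  # 131072
```

Each row has at most W allowed positions, so the attention cost per token and the number of cached K and V entries per layer stop growing once the sequence is longer than W.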


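Because the attention window is fixed, Mistral 7B uses a rolling buffer cache of fixed size W. The K and V for position i are stored at slot i mod W of the cache, so once i reaches W or beyond, the values previously stored in that slot are overwritten. The figure below illustrates this for W = 3.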
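The i mod W indexing can be sketched with a small fixed-size buffer. This is an illustrative toy, not Mistral's implementation: the class name RollingKVCache is hypothetical, and strings stand in for the actual key and value tensors.

```python
class RollingKVCache:
    """Fixed-size cache where position i is written to slot i mod window."""

    def __init__(self, window):
        self.window = window
        self.keys = [None] * window
        self.values = [None] * window
        self.positions = [None] * window  # which position each slot currently holds

    def store(self, i, k, v):
        slot = i % self.window  # older entries in this slot are overwritten
        self.keys[slot] = k
        self.values[slot] = v
        self.positions[slot] = i

    def visible_positions(self):
        """Positions whose K and V are still in the cache, in order."""
        return sorted(p for p in self.positions if p is not None)


cache = RollingKVCache(window=3)
for i in range(5):
    cache.store(i, f"k{i}", f"v{i}")

print(cache.keys)                 # ['k3', 'k4', 'k2']
print(cache.visible_positions())  # [2, 3, 4]
```

After five positions with W = 3, positions 0 and 1 have been overwritten by positions 3 and 4, so memory use stays constant regardless of sequence length.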


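  1. 看图学KV Cache (Learning KV Cache Through Diagrams)
  2. Transformer Inference Arithmetic
  3. Transformers KV Caching Explained (its GIF animations help build intuition)
  4. KV caching内存增长 (Memory Growth of KV Caching)
  5. KV cache 是chatbot 规模化的一大工程挑战 (KV Cache Is a Major Engineering Challenge for Scaling Chatbots)
  6. Techniques for KV Cache Optimization in Large Language Models
  7. KV cache quantization
  8. Inference Optimization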




