NLP高频面试题(十七)——什么是KV Cache

news/2025/4/2 4:55:09/

在当今火热的大语言模型领域,模型的参数动辄数十亿甚至上千亿,随着输入的上下文(token长度)增加,推理过程中的计算量和显存消耗都会显著增加。其中,KV Cache 是大模型推理过程中的一种重要优化技术。

本文将围绕 KV Cache 详细展开,帮助你深入理解这个关键技术的原理、优势以及相关的优化方案。

一、什么是 KV Cache?

KV Cache,全称为 Key-Value Cache,是在Transformer模型推理过程中,为减少重复计算、降低内存开销而设计的一种缓存机制。具体来说:

  • Transformer 模型中,每生成一个新词(token)时,都需要计算该词与前面所有词之间的注意力(attention)。
  • 注意力计算涉及 Query(Q)、Key(K) 和 Value(V) 三个张量,其中 Key 和 Value 对于已生成的 token 是不变的,只有 Query 会随每次生成而更新。
  • KV Cache 就是将这些已经计算好的 Key 和 Value 存储起来,供下一次生成 token 时直接复用,而不必重复计算。

这种缓存机制极大提升了推理效率,尤其在长序列的自回归生成场景中非常有效。

二、为什么需要 KV Cache?

以大模型的推理过程为例,我们一般分为两个阶段:

  • 预填充阶段(Prefill stage)
    模型处理整个输入序列,这个阶段高度并行化,利用 GPU 效率高。

  • 解码阶段(Decode stage)
    模型逐个生成新 token,每生成一个 token,都需要与之前所有 token 进行注意力计算,这个阶段效率较低。

在解码阶段,如果每次生成新 token 都重新计算 Key 和 Value,将产生大量冗余计算,推理效率极低。通过使用 KV Cache,模型可以避免重复计算,大大降低了内存带宽需求和计算成本,显著提升推理速度。


三、KV Cache 的实现原理(附代码示例)

KV Cache 的实现非常简单,其核心思想是:

  • 在 Transformer 模型的每个 self-attention 层存储过去的 Key 和 Value。
  • 当生成新 token 时,只需计算当前 token 对应的 Key 和 Value,并将它们追加到过去缓存的 Key 和 Value 后面。

下面是一段使用 PyTorch 实现的简化示例:

import torch# 假设 key_states 和 value_states 是当前 token 的计算结果
# past_key_value 是缓存的历史 Key 和 Valuedef update_kv_cache(past_key_value, key_states, value_states, use_cache=True):if past_key_value is not None:# 将历史的 Key 和当前的 Key 拼接key_states = torch.cat([past_key_value[0], key_states], dim=-2)value_states = torch.cat([past_key_value[1], value_states], dim=-2)# 更新缓存past_key_value = (key_states, value_states) if use_cache else Nonereturn past_key_value

这里的 past_key_value 通常是一个元组 (key_states, value_states),每个 tensor 的形状一般为 (batch_size, num_heads, seq_len, head_dim)

值得注意的是,Query 并不缓存,因为每次推理只关心最新生成的那个 token,因此每次只需用最新的 Query 向量即可。

四、优化 KV Cache:MQA 与 GQA

随着 KV Cache 应用的广泛,一些变种注意力机制诞生,例如:

(1)多查询注意力 (MQA, Multi-Query Attention)

MQA 中所有的头共享一组 Key 和 Value 矩阵,相比传统的 MHA (Multi-Head Attention),缓存的 KV 体积显著降低:

  • MHA 缓存大小:(num_heads, seq_len, head_dim)
  • MQA 缓存大小:仅需 (seq_len, head_dim)

这种方式虽然有效节约了缓存,但容易损失精度,通常需要额外的训练优化。

(2)分组查询注意力 (GQA, Grouped Query Attention)

GQA 是 MHA 和 MQA 的一种折中方案:

  • 将注意力头分成若干组,每组内头共享一份 Key 和 Value。
  • GQA-N 表示头数被分成 N 组,GQA-1 即为 MQA,GQA-头数即为传统 MHA。

GQA 能够在内存效率和模型精度之间取得平衡,像知名的 Llama2 70B 模型正是采用了 GQA。

GQA实现示例:

def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:bs, slen, n_kv_heads, head_dim = x.shapeif n_rep == 1:return xreturn (x[:, :, :, None, :].expand(bs, slen, n_kv_heads, n_rep, head_dim).reshape(bs, slen, n_kv_heads * n_rep, head_dim))

这段代码的核心作用就是将少数几个 Key 和 Value 头扩展到更多的头,实现分组共享。

五、KV Cache 的挑战和进一步优化

虽然 KV Cache 在推理阶段提供了显著的加速,但其内存占用仍然较大,尤其是在批量推理和长序列情况下。

为了解决这一挑战,目前已有多种策略:

  • 分页缓存(Paged KV Cache)
    仿照操作系统分页机制,将 KV 缓存分成固定大小的块,动态分配和释放,从而减少碎片、提高内存利用率。

  • Flash Attention
    通过重新安排注意力计算顺序,将多次内存读写优化为一次性处理,极大提高了 GPU 利用效率,降低 KV Cache 存取开销。

  • 量化(Quantization)和稀疏化(Sparsity)
    对 KV 缓存做低精度量化或稀疏表示,进一步降低内存占用。


http://www.ppmy.cn/news/1584525.html

相关文章

如何屏蔽mac电脑更新提醒,禁止系统更新

最烦mac的系统更新提醒了,过几天就是更新弹窗提醒,现在可以直接禁掉了,眼不见心不乱,不然一升级,开发环境全都不能用了,那才是最可怕的,屏蔽的方法也很简单,就是屏蔽mac系统更新的请…

太阳能台风预警宣传信号智慧杆:科技赋能防灾减灾的新标杆

在全球气候变化持续加剧、台风灾害频繁发生的大背景之下,借助科技手段提高预警效率以及保障公共安全,已然成为现代城市管理领域的关键课题。太阳能台风预警宣传信号智慧杆(以下简称 “智慧杆”)适时出现,凭借其以绿色能…

机器学习的一百个概念(3)上采样

前言 本文隶属于专栏《机器学习的一百个概念》,该专栏为笔者原创,引用请注明来源,不足和错误之处请在评论区帮忙指出,谢谢! 本专栏目录结构和参考文献请见[《机器学习的一百个概念》 ima 知识库 知识库广场搜索&…

css基础之浮动相关学习

一、浮动基本介绍 在最初&#xff0c;浮动是用来实现文字环绕图片效果的&#xff0c;现在浮动是主流的页面布局方式之一。 效果/代码 图片环绕 代码 div {width: 600px;height: 400px;background-color: skyblue;}img {width: 200px;float: right;margin-right: 0.5em;}<…

想弄清VR和AR区别,这一篇文章就够了

一、VR 与 AR 的定义差异 VR 即虚拟现实&#xff0c;是通过计算机生成的虚拟环境&#xff0c;用户可通过佩戴设备完全沉浸其中。比如&#xff0c;虚拟现实技术通过计算机模拟产生一个包含三维空间和时间的虚拟世界&#xff0c;使得用户对模拟场景产生身临其境的感觉。戴上 VR 眼…

Vue 项目安装依赖报错:errno -4048

笔记&#xff1a; 报错 使用管理换身份打开重新 运行 npm install 就好&#xff01; 报错 原因是 因为 当前 node.js 版本过高 需要降低node版本 重新运行 npm install 就好 降级 Node.js 版本&#xff1a; 根据错误提示&#xff0c;achrinza/node-ipc9.2.2 支持的最高版本是 N…

ora-38301:oracle的回收站临时表异常

最近&#xff0c;私人计算机的oracle意外出现异常错误 ora-38301:can not execute DDL/DML to recycle object. 个人估计可能原因如下&#xff1a; 1. 与使用truncate有关&#xff1b; 2. 可能是因为我的客户端工具有两类&#xff1a;PL/SQL 和 eclipse的Data Source Explo…

习题1.26

解释题&#xff0c;说简单也简单&#xff0c;难在如何表达清楚。 首先解释下代码的变化 (defn expmod[base exp m](cond ( exp 0) 1(even? exp) (mod (square (expmod base (/ exp 2) m)) m):else (mod (* base (expmod base (- exp 1) m)) m)))(defn expmod[base exp m](co…