- 大模型推理的两个阶段:prefill和decode
- prefill阶段处理整个输入序列,并生成第一个输出token,并初始化KV cache
- decode阶段则逐个生成后续的tokens
- KV cache
- 无KV cache的推理过程:每次生成新token时,需要将整个历史序列(prompt+已生成tokens)重新输入模型,并重新计算所有tokens的Key和Value向量
- 有KV cache的推理过程:在prefill阶段计算prompt的Key/Value并缓存,后续decode阶段仅需计算新token的Key/Value,然后将其与缓存的历史Key/Value合并
- 优点:
- 计算量降低:将复杂度从 O ( n 2 ) O(n^2) O(n2)将至 O ( n ) O(n) O(n),每个decode步骤仅需计算新token的Key/Value
- 显存优化:没有的话会存储完整历史序列的中间结果,而用了KV cache仅需存储缓存的Key/Value,通过复用缓存避免了冗余存储
- MQA、GQA:通过多头注意力机制的降本增效,来降低KV cache的大小
- MHA:多头注意力,就是有多个头,每个头有各自的 W Q W^Q WQ、 W V W^V WV、 W K W^K WK来生成自己的Q、K、V,最终会结合在一起。
- MQA:通过在attention机制里共享keys和values来减少KV cache的内容,
- GQA:不是所有的query共享一组KV,而是一个group的query共享一组KV,这样既降低了KV cache,又能满足精度,属于MHA和MQA之间的折中方案
- MLA
上面的MQA和GQA是在缓存多少数量的KV思路上进行优化:直觉上如果缓存的KV个数少,显存占用就少,大模型能力的降低可以通过进一步的训练或者增加FFN/GLU的规模来弥补。
另外一个优化方向就是MLA,它的想法是让缓存的K、V本身变小。原理就是一个MN的矩阵可以近似成两个小矩阵Mk和kN的乘积,也就是把K/V矩阵都拆成两个小矩阵来存储(底层逻辑就是认为低秩空间足够表达高维有用信息)。具体的,K和V共用一个降维矩阵,然后再各自有自己的升维矩阵。即下图中共用 W D K V W^{DKV} WDKV得到相同的 C t K V C_t^{KV} CtKV,以及各自有升维矩阵 W U K W^{UK} WUK和 W U V W^{UV} WUV
把大矩阵拆成两个小矩阵的乘积,缓存是变小了,但是计算量会变多从而失去了缓存的意义。解决方法是把其中的一个小矩阵在计算时,提前放到前面一块沉掉了,然后在每次推理的时候就用这个已经沉掉的值去做推理。这样的话在推理的时候计算量是不会变的。(这里需要一个线性代数的数学推导,总之就是通过一些运算技巧,让计算变少了)
另外还遇到一个问题是RoPE位置编码不兼容MLA算法,解决办法是增加几维向量专门用来做这个RoPE,而不是把所有的向量都用来做RoPE