大模型学习笔记 - LLM 之 attention 优化

devtools/2024/9/23 16:43:38/

LLM__0">LLM 注意力机制

  • LLM 注意力机制
    • 1. 注意力机制类型概述
    • 2.Group Query Attention
    • 3.FlashAttention
    • 4. PageAttention

1. 注意力机制类型概述

注意力机制最早来源于Transformer,Transformer中的注意力机制分为2种 Encoder中的 全量注意力机制和 Decoder中的带mask的注意力机制。这两种注意力机制 都是 MultiHeadAttention 由Key,Query, Value 三个矩阵组成。

由于经典的MHA的计算时间和缓存占用量都是O(n^2)级别的(n是序列长度),这就意味着如果序列长度变成原来的 2 倍,显存占用量就是原来的 4 倍,计算时间也是原来的 4 倍。当然,假设并行核心数足够多的情况下,计算时间未必会增加到原来的 4 倍,但是显存的 4 倍却是实实在在的无可避免,这也是之前微调 Bert 的时候时不时就来个 OOM 的原因了。

所以不少工作致力于研究 降低Attention的计算复杂度和缓存大小,从而使复杂度从O(n^2) 降低到O(nlogn) 甚至O(n).

  • 稀疏attention: SparseAttention,Longformer
    • Reformer,Linformer:
    • Linear Attention 思想: Q K t QK^t QKt 这一步我们得到一个nn的矩阵,就是这一步决定了Attention的复杂度是O(n^2);如果没有Softmax,那么就是三个矩阵连乘 Q K t V QK^{t}V QKtV,而矩阵乘法 是满足结合率的,所以我们可以先算 K t V K^{t}V KtV,得到矩阵dd,然后再用Q左乘它,由于d<<n,所以这样算大致的复杂度只有O(n)(就是Q左乘的那一步占主导)。也就说,去掉softmax的Attention的复杂度可以降低到最理想的线性级别。这显然是我们的终极追求:Linear Attention,复杂度为线性级别的Attention.

优化计算量和缓存后,LLM时代,推理速度加速的成为一个问题,于是针对推理慢的开始进行如下优化

  1. IO传输瓶颈: 斯坦福团队发现 影响推理速度的瓶颈不在于计算量,而是IO传输。于是提出了减少IO传输的FlashAttention 1/2/3. FlashAttention论文的目标是尽可能高效地使用SRAM来加快计算速度。
  2. GPU显存瓶颈:研究人眼引入了 PagedAttention,这是一种受操作系统中虚拟内存和分页经典思想启发的注意力算法。与传统的注意力算法不同,PagedAttention 允许在非连续的内存空间中存储连续的 key 和 value 。具体来说,PagedAttention 将每个序列的 KV cache 划分为块,每个块包含固定数量 token 的键和值。在注意力计算期间,PagedAttention 内核可以有效地识别和获取这些块。
  3. 减少推理缓存: **GQA(group query attention)**分组注意力机制,在MQA基础上,增加多组Key,Value(但是不是全量),每个head独立拥有Query。

2.Group Query Attention

在自回归解码的标准做法是缓存序列中先前标记的键(K)和值(V) 对,从而加快注意力计算速度。然而,随着上下文窗口或批量大小的增加,多头注意力 (MHA)模型中与 KV 缓存大小相关的内存成本显着增长,所以随着上下文窗口增加,KV缓存大小成为瓶颈,为了扩展上下文,减少注意力机制的计算量和缓存大小,从而研究者开始对全量注意力机制的优化进行研究,目前主流的注意力机制主要分为3种:

  1. MHA(multi-head attention)全量注意力机制,每个head 独立拥有K Q V。
  2. MQA(multi-query attention)多查询注意力机制,多个head共享1组Key,Value,每个head独立拥有Query。
    1. 由于只是用一个 key 和value,大大加快解码推断的速度,但是可能导致质量下降。
    2. 目前ChatGLM2-6B使用的是这个
  3. GQA(group query attention)分组注意力机制,在MQA基础上,增加多组Key,Value(但是不是全量),每个head独立拥有Query。
    1. LLaMA2 和 Mistral采用的是这个。
    2. 属于1和2的折中,KV个数在1-head 中间.
    3. GQA 变体在大多数评估任务上的表现与 MHA 基线相当,并且平均优于 MQA 变体

在这里插入图片描述

3.FlashAttention

4. PageAttention

tobe added


http://www.ppmy.cn/devtools/98521.html

相关文章

生活垃圾填埋场污染监测:新标准下的技术革新与环境保护

随着城市化进程的加速&#xff0c;生活垃圾产生量急剧增加&#xff0c;如何有效处理并控制其带来的环境污染成为亟待解决的问题。近日&#xff0c;生态环境部发布了新修订的《生活垃圾填埋场污染控制标准》&#xff08;GB 16889-2024&#xff09;&#xff0c;将自2024年9月1日起…

K8S资源之PVPVC

概念 类似于Docker的数据卷挂载&#xff0c;将Pod中重要的文件挂载到宿主机上&#xff0c;如果Pod发生崩溃等情况自愈时&#xff0c;保证之前存储的数据没有丢失。 如上图中&#xff0c;将各个Pod中的目录挂载到存储层&#xff0c;如果Pod宕机后自愈均从存储层获取之前的数据…

【python】灰色预测 GM(1,1) 模型

文章目录 前言python代码 前言 用 python 复刻上一篇博客的 Matlab 代码。 【学习笔记】灰色预测 GM(1,1) 模型 —— Matlab python代码 # %% import numpy as np import statsmodels.api as sm import matplotlib.pyplot as plt from matplotlib.pylab import mplmpl.rcPa…

CodeLLDB的快速安装

1、CodeLLDB很难安装 ‌‌CodeLLDB插件是一个基于‌LLDB的调试器插件&#xff0c;专为‌Visual Studio Code设计&#xff0c;旨在提供类似于传统集成开发环境&#xff08;IDE&#xff09;的调试体验。‌ 它支持‌C、‌C和‌Objective-C程序的调试&#xff0c;包括设置断点、查…

python mysql insert 时 获取 自增 id的值

在MySQL中&#xff0c;当你使用INSERT语句插入一行数据到拥有自增主键的表时&#xff0c;你可以通过使用LAST_INSERT_ID()函数来获取这个新的自增ID值。 以下是一个简单的例子&#xff1a; 假设你有一个表users&#xff0c;它有一个自增的主键id&#xff0c;和其他一些字段比…

Linux下opencv报错 undefined reference to cv::imread cv::Mat

如果你是和libtorch一起使用&#xff0c;那么请你继续&#xff0c;否则该篇文章不适合你。 正文 在https://pytorch.org/下 下载的时候要选择Cxx11 ABI版 随后正常配置就可以了

用py获取显卡的占用率

这是什么 这是一个py 编写的程序&#xff0c;功能上面是用于获取 NVIDIA 显卡的占用率&#xff0c;并通过串口将其发送出去。同时&#xff0c;程序也会读取串口接收到的数据并显示在终端上&#xff0c;这样方便调试。 注意 因为我用的是N卡所以这个程序限制N卡使用&#xff0…

达梦数据库表结构导出到 Excel 教程

在数据库开发和维护中&#xff0c;导出数据表结构是常见的需求之一&#xff0c;特别是在进行数据库文档化、系统迁移、版本控制等工作时。通过导出表结构到 Excel&#xff0c;我们可以方便地查看、分析和分享表结构信息。在本文中&#xff0c;我将结合达梦数据库的相关 SQL 查询…