Flash Attention

ops/2024/12/20 18:58:59/

文章目录

  • Flash Attention: 高效注意力机制解析
    • 什么是 Flash Attention?
    • Flash Attention 与普通 Attention 的对比
    • 为什么选择 Flash Attention?
      • 优点
      • 局限性
    • Flash Attention 的工作原理
      • 核心机制
    • Flash Attention 实现代码
      • 普通 Attention 示例
      • Flash Attention 示例(简化)
  • Flash Attention 的工作原理展示
    • 核心流程
    • 示例:分块计算
      • 输入矩阵
        • 示例输入矩阵

Flash Attention: 高效注意力机制解析

什么是 Flash Attention?

Flash Attention 是一种针对 Transformer 模型 优化的高效注意力计算方法。与传统注意力机制相比,它通过 分块计算显存优化数值稳定性改进,实现了在 长序列任务 中的显著加速,同时大幅降低了显存占用。


Flash Attention 与普通 Attention 的对比

特性普通 AttentionFlash Attention
计算复杂度 O ( n 2 ) O(n^2) O(n2),长序列显存占用高 O ( n 2 ) O(n^2) O(n2),通过分块优化显存使用
显存占用必须存储完整的注意力矩阵 n × n n \times n n×n分块计算避免存储完整矩阵,显存开销显著降低
数值稳定性可能因 Softmax 计算溢出导致不稳定分块归一化(log-sum-exp 技术)保证数值稳定性
适用场景适合短序列任务长序列任务的理想选择,如长文档建模、视频建模

为什么选择 Flash Attention?

优点

  1. 显存高效:避免存储完整的注意力矩阵,支持更长的序列处理。
  2. 计算快速:使用分块和 CUDA 优化,比普通 Attention 加速 2-4 倍。
  3. 数值稳定:改进 Softmax 的实现,支持更大的输入范围。
  4. 适合长序列任务:如 NLP 长文档处理、生物信息学蛋白质序列建模、高分辨率视频分析。

局限性

  1. 实现复杂:依赖 CUDA 核心优化,难以手动实现完整功能。
  2. 硬件要求高:需要现代 GPU 和高效的内存管理。

Flash Attention 的工作原理

核心机制

  1. 传统公式
    Attention ( Q , K , V ) = Softmax ( Q K T d k ) V \text{Attention}(Q, K, V) = \text{Softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V Attention(Q,K,V)=Softmax(dk QKT)V

  2. Flash Attention 的优化

    • 分块计算:避免存储完整的 n × n n \times n n×n 矩阵。
    • 块内归一化
      Softmax ( x ) = exp ⁡ ( x − max ⁡ ( x ) ) ∑ exp ⁡ ( x − max ⁡ ( x ) ) \text{Softmax}(x) = \frac{\exp(x - \max(x))}{\sum \exp(x - \max(x))} Softmax(x)=exp(xmax(x))exp(xmax(x))
    • CUDA 并行化:结合 kernel fusion 实现高效矩阵运算。

Flash Attention 实现代码

普通 Attention 示例

import torchdef attention(Q, K, V):d_k = Q.size(-1)scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(d_k))attention = torch.softmax(scores, dim=-1)output = torch.matmul(attention, V)return output

Flash Attention 示例(简化)

def flash_attention(Q, K, V, block_size=32):batch_size, seq_len, hidden_dim = Q.size()d_k = hidden_dimoutput = torch.zeros_like(Q)for i in range(0, seq_len, block_size):for j in range(0, seq_len, block_size):Q_block = Q[:, i:i+block_size, :]K_block = K[:, j:j+block_size, :]V_block = V[:, j:j+block_size, :]scores = torch.matmul(Q_block, K_block.transpose(-2, -1)) / torch.sqrt(torch.tensor(d_k))max_scores = torch.max(scores, dim=-1, keepdim=True)[0]scores = scores - max_scoresattention = torch.exp(scores)attention = attention / torch.sum(attention, dim=-1, keepdim=True)output[:, i:i+block_size, :] += torch.matmul(attention, V_block)return output

Flash Attention 的工作原理展示

Flash Attention 的核心优化在于 分块计算(Blockwise Attention Calculation),通过分块减少显存占用,并保持计算效率和数值稳定性。以下是 Flash Attention 的工作流程及分块计算的具体实现细节:


核心流程

Flash Attention 的实现主要分为以下几步:

  1. 输入序列分块

    • 将输入的 QKV 分成小块(block_size),避免一次性计算完整的注意力矩阵。
    • 每个块分别计算局部的点积、Softmax 和加权结果。
  2. 块内注意力计算

    • 对每个块内计算注意力分布,使用数值稳定的 Softmax 优化,避免数值溢出问题。
  3. 逐块累积输出

    • 将分块结果逐步累积,得到最终的全局注意力输出。

示例:分块计算

输入矩阵

假设:

  1. 输入矩阵 Q(Query):形状为 4 × 4,表示序列长度为 4,隐藏维度为 4。
  2. 输入矩阵 K(Key):形状为 4 × 4,与 Q 的形状一致。
  3. 输入矩阵 V(Value):形状为 4 × 4,与 QK 的形状一致。
  4. 分块大小 block_size:假设为 2,表示每次处理 2 个序列块。
示例输入矩阵
Q = [[1, 2, 3, 4],[4, 3, 2, 1],[1, 1, 1, 1],[2, 2, 2, 2]]K = [[1, 0, 1, 0],[0, 1, 0, 1],[1, 1, 1, 1],[2, 2, 2, 2]]V = [[1, 1, 1, 1],[2, 2, 2, 2],[3, 3, 3, 3],[4, 4, 4, 4]]步骤 1:分块将 Q 和 K 按行进行分块,每块大小为 block_size=2:Q 分块:
Q_1 = [[1, 2, 3, 4],[4, 3, 2, 1]]Q_2 = [[1, 1, 1, 1],[2, 2, 2, 2]]K 分块:
K_1 = [[1, 0, 1, 0],[0, 1, 0, 1]]K_2 = [[1, 1, 1, 1],[2, 2, 2, 2]]V 分块:
V_1 = [[1, 1, 1, 1],[2, 2, 2, 2]]V_2 = [[3, 3, 3, 3],[4, 4, 4, 4]]步骤 2:块间点积计算计算每个块的点积 Q_block × K_block^T,并缩放:计算 ( Q_1 \times K_1^T ):Q_1 × K_1^T = [[1, 2, 3, 4]   ×   [[1, 0, 1, 0]^T    = [[15, 14],[4, 3, 2, 1]]        [0, 1, 0, 1]]     [10, 20]]缩放结果(假设隐藏维度 ( d_k = 4 ),缩放因子为 ( \sqrt{4} = 2 )):Scores = [[15 / 2, 14 / 2],[10 / 2, 20 / 2]] = [[7.5, 7.0],[5.0, 10.0]]计算 ( Q_1 \times K_2^T ):Q_1 × K_2^T = [[1, 2, 3, 4]   ×   [[1, 1, 1, 1]^T    = [[30, 60],[4, 3, 2, 1]]        [2, 2, 2, 2]]     [20, 40]]缩放结果:Scores = [[30 / 2, 60 / 2],[20 / 2, 40 / 2]] = [[15.0, 30.0],[10.0, 20.0]]步骤 3:数值稳定的 Softmax 计算使用 最大值减法 技术对每个分块的 Scores 计算 Softmax,避免数值溢出。对 ( Q_1 \times K_1^T ) 的 Scores 计算:Scores = [[7.5, 7.0],[5.0, 10.0]]最大值减法:[[7.5 - 7.5, 7.0 - 7.5],[5.0 - 10.0, 10.0 - 10.0]] = [[0, -0.5],[-5,  0]]指数计算:[[exp(0), exp(-0.5)],[exp(-5), exp(0)]] = [[1.0, 0.6065],[0.0067, 1.0]]Softmax 归一化:[[1.0 / (1.0 + 0.6065), 0.6065 / (1.0 + 0.6065)],[0.0067 / (0.0067 + 1.0), 1.0 / (0.0067 + 1.0)]] = [[0.622, 0.378],[0.007, 0.993]]步骤 4:加权输出计算使用 Softmax 权重和 V 的块计算加权输出。( Q_1 ) 的加权结果:Output_1 = Softmax(Q_1 × K_1^T) × V_1 = [[0.622, 0.378]   ×   [[1, 1, 1, 1],[2, 2, 2, 2]]     [2.622, 2.622, 2.622, 2.622]]类似地,依次计算 ( Q_2 × K_2^T ) 的输出,逐块累积所有块的结果。最终输出将所有分块的输出累加到最终结果矩阵中,得到完整的注意力结果矩阵 Output。Flash Attention 的优势1.	显存优化:•	普通 Attention 需要存储完整的注意力矩阵,显存占用为 ( O(n^2) )。•	Flash Attention 仅存储分块结果,显存占用为 ( O(n \cdot \text{block_size}) )。2.	计算效率:•	分块计算可以并行化,结合 CUDA 核心优化,速度显著提高。3.	数值稳定性:•	使用块级 Softmax 和最大值归一化,避免长序列点积的数值溢出问题。

http://www.ppmy.cn/ops/143555.html

相关文章

矩阵论:Vector-Valued Linear and Affine Functions介绍:中英双语

最近在翻看 这本书,对其中的一些英文概念做一些记录。 Link:https://web.stanford.edu/~boyd/books.html 中文版 向量值线性函数和仿射函数的详解 在机器学习、数据科学和工程应用中,向量值线性函数和仿射函数是非常重要的数学工具。本…

自动驾驶AVM环视算法--python版本的540投影模式

c语言版本和算法原理的可以查看本人的其他文档。《自动驾驶AVM环视算法--540度全景的算法实现和exe测试demo》本文档进用于展示部分代码的视线,获取方式网盘自行获取(非免费介意勿下载):链接: https://pan.baidu.com/s/19fxwrZ3Bb…

【go每日一题】 实现生产者消费者模式

基本描述 golang使用并发编程,实现一个生产者消费者模式,消费的任务耗时1-3秒,希望最终10秒内能够消费尽可能多的任务 代码 package testimport ("fmt""math/rand""testing""time" )type Consume…

【python实战】-- 解压提取所有指定文件的指定内容

系列文章目录 文章目录 系列文章目录前言一、pandas是什么?1、需求2、程序 总结 前言 一、pandas是什么? 1、需求 指定目录下有若干文件 批量解压 需要汇总包含指定字符的所有文件中的指定数据 2、程序 import os import shutil import zipfile impor…

LIF神经元模型的显隐转换

本文星主将介绍LIF神经元模型的显式和隐式转换(星主看见有论文[1]是这个称呼的,所以本文也称显式和隐式),并得到隐式模型的解析解。注意:理解本文内容需要有一定的微积分基础,如果大家看着数学头疼&#xf…

【GCC】2015: draft-alvestrand-rmcat-congestion-03 机器翻译

腾讯云的一个分析,明显是看了这个论文和草案的 : 最新的是应该是这个 A Google Congestion Control Algorithm for Real-Time Communication draft-ietf-rmcat-gcc-02 下面的这个应该过期了: draft-alvestrand-rmcat-congestion-03

Differential Transformer: 通过差分注意力机制提升大语言模型性能

Transformer模型已经成为大语言模型(LLMs)的标准架构,但研究表明这些模型在准确检索关键信息方面仍面临挑战。今天介绍一篇名叫Differential Transformer的论文,论文的作者观察到一个关键问题:传统Transformer模型倾向…

使用宝塔面板中的Nginx部署前端Vue项目

我相信已经到这一步了,Nginx和宝塔过多的描述我就不说了,直接上干货实操。 第一步:前端项目打包 直接运行Npm run build命令进行打包,会打成一个dist的压缩包 注意:我们前端打包的时候要修改我们连接后端接口的&…