从Online Softmax到FlashAttention

embedded/2025/3/17 5:16:27/

目录

    • 前言
    • 0. Abstract
    • 1. The Self-Attention
    • 2. (Safe) Softmax
    • 3. Online Softmax
    • 4. FlashAttention
    • 结语
    • 参考

前言

最近在学习 FlashAttention,看到一份不错的手稿分享下🤗

manuscript:From Online Softmax to FlashAttention

0. Abstract

FlashAttention 的关键创新是使用类似于 Online Softmax 的思想来 tile 分块 self-attention 的计算,这样可以融合整个多头注意力层,而无需多次重复访问 GPU 的全局内存来获取临时变量和注意力分数矩阵 A A A。本文将简要解释如何 tiling 分块 self-attention 的计算,以及如何从 online softmax 中推导出 FlashAttention 计算

1. The Self-Attention

Self-Attention 的计算可以描述为:

O = s o f t m a x ( Q K T ) V \begin{equation} O = \mathrm{softmax}\left(QK^T\right)V \end{equation} O=softmax(QKT)V

其中 Q , K , V , O ∈ R L × D Q,K,V,O \in \mathbb{R}^{L\times D} Q,K,V,ORL×D L L L 是序列长度, D D D 是每个头的维度,softmax 将按列应用于 Q K T QK^T QKT

Note:这里我们忽略了多头、多 batch,因为在这些维度上的计算是完全并行的,也就是说我们只关注单头、单 batch 的计算就行。另外为了简单起见,我们也忽略了注意力掩码以及缩放因子 1 D \frac{1}{\sqrt{D}} D 1

计算 self-attention 的标准方法是将其分解为以下几个阶段:

X = Q K T A = s o f t m a x ( X ) O = A V \begin{align} X & = QK^T \\ A & = \mathrm{softmax}(X) \\ O & = AV \end{align} XAO=QKT=softmax(X)=AV

我们称 X X X 矩阵为 pre-softmax logits, A A A 矩阵为注意力分数(attention score), O O O 矩阵为输出

FlashAttention 的一个惊人之处在于,我们不需要在全局内存(global memory)上实现 X X X A A A 矩阵,而是将公式 (1) 中的整个计算融合到单个 CUDA kernel 中。这要求我们设计一种能够精心管理片上内存(on-chip memory)的算法,因为 NVIDIA GPU 的共享内存(shared memory)非常小

对于矩阵乘法等经典算法,我们通常会使用 tiling 技术来确保片上内存不超过硬件限制。下图提供了一个例子,在 kernel 执行期间,无论矩阵形状如何,只有 3 T 2 3T^2 3T2 个元素存储在片上。这种 tiling 方法之所以有效,是因为加法是关联的,允许将整个矩阵乘法分解为许多 tile-wise 矩阵乘法的总和

然而,Self-Attention 包含一个不直接关联的 softmax 运算符,因此很难像下图那样简单地对 Self-Attention 进行 tile,那有没有办法让 softmax 具有关联性呢?🤔

在这里插入图片描述

上图简要说明了如何对矩阵乘法 C = A × B C=A\times B C=A×B 的输入和输出矩阵进行 tile(分块),矩阵被划分为 T × T T\times T T×T 个小块。对于每个输出块,我们从左到右遍历 A A A 中的相关块,从上到下遍历 B B B 中的相关块,并将它们的值从全局内存(global memory)加载到片上内存(on-chip memory)(以蓝色表示,整体片上内存占用为 O ( T 2 ) O(T^2) O(T2)

接着逐块进行矩阵乘法,对于位置 ( i , j ) (i,j) (i,j),我们先从片上内存中加载块内所有 i i i 行的元素即 A [ i , k ] A[i,k] A[i,k] 以及所有 j j j 列的元素即 B [ k , j ] B[k,j] B[k,j](以红色表示),然后利用 A [ i , k ] A[i,k] A[i,k] B [ k , j ] B[k,j] B[k,j] 计算得到片上内存中的 C [ i , j ] C[i,j] C[i,j]。以此循环,直到完成一个块的计算后,我们再将片上 C C C 块的结果写回主内存并继续处理下一个块

实际 tiling 的应用要复杂得多,大家可以参考:cutlass 在 A100 上实现矩阵乘法

2. (Safe) Softmax

我们先回顾一下 softmax 算子,下面是 softmax 计算的通用公式:

s o f t m a x ( { x 1 , … , x N } ) = { e x i ∑ j = 1 N e x j } i = 1 N \begin{equation} \mathrm{softmax}(\{x_1,\ldots,x_N\})=\left\{\frac{e^{x_i}}{\sum_{j=1}^Ne^{x_j}}\right\}_{i=1}^N \end{equation} softmax({x1,,xN})={j=1Nexjexi}i=1N

值得注意的是,当 x i x_i xi 比较大时 e x i e^{x_i} exi 很容易溢出,例如 float16 可以支持的最大值是 65536,这意味着当 x ⩾ 11 x \geqslant 11 x11 时, e x e^x ex 将超过 float16 的有效范围

为了解决这个问题,像 pytorch、tensorflow 等框架通常会使用一种被称为 “safe” 的 softmax 计算方法:

e x i ∑ j = 1 N e x j = e x i − m ∑ j = 1 N e x j − m \begin{equation} \frac{e^{x_{i}}}{\sum_{j=1}^{N}e^{x_{j}}}=\frac{e^{x_{i}-m}}{\sum_{j=1}^{N}e^{x_{j}-m}} \end{equation} j=1Nexjexi=j=1Nexjmexim

其中 m = max ⁡ j = 1 N ( x j ) m=\max_{j=1}^{N}(x_j) m=maxj=1N(xj)

我们将 x i x_i xi 减去它们中的最大值,这样可以确保每个元素 x i − m ⩽ 0 x_i-m \leqslant 0 xim0,从而再做指数运算时可以确保其值不会溢出

我们可以将 safe softmax 的计算总结为下面的 3-pass 算法:

Algorithm 3-pass safe softmax

NOTATIONS

  • { m i } \{m_i\} {mi} max ⁡ j = 1 i { x j } \max_{j=1}^i\{x_j\} maxj=1i{xj},初始值 m 0 = − ∞ m_0=-\infty m0=
  • { d i } \{d_i\} {di} ∑ j = 1 i e x j − m N \sum_{j=1}^{i}e^{x_{j}-m_{N}} j=1iexjmN,初始值 d 0 = 0 d_0=0 d0=0 d N d_N dN 是 safe softmax 的分母
  • { a i } \{a_i\} {ai}:最终的 softmax

BODY

for i ← 1 , N i\leftarrow1,N i1,N do

m i ← max ⁡ ( m i − 1 , x i ) \begin{equation} m_i \leftarrow \max (m_{i-1},x_i) \end{equation} mimax(mi1,xi)

end

for i ← 1 , N i\leftarrow1,N i1,N do

d i ← d i − 1 + e x i − m N \begin{equation} d_i \leftarrow d_{i-1} + e^{x_i-m_N} \end{equation} didi1+eximN

end

for i ← 1 , N i\leftarrow1,N i1,N do

a i ← e x i − m N d N \begin{equation} a_i \leftarrow \frac{e^{x_i-m_N}}{d_N} \end{equation} aidNeximN

end

该算法要求我们对 [ 1 , N ] [1,N] [1,N] 进行 3 次迭代,在 Transformer 的 self-attention 中,每个 x i x_i xi 是通过 Q K T QK^T QKT 计算得到的。如果我们无法将所有的 x i x_i xi 都缓存到片上内存(SRAM)中(实际上我们也没有足够大的片上内存 SRAM 来容纳所有的 x i x_i xi),那么在每次迭代时不得不访问 Q , K Q,K Q,K 以动态重新计算 x i x_i xi,这种频繁的访问会导致大量的 I/O 操作,从而降低整体效率

3. Online Softmax

如果我们能够在一个循环中融合公式 (7)、(8)、(9),那么就可以将 global memory 的访问时间降低 3 倍,不幸的是,我们不能在同一个循环中融合公式 (7)、(8),因为公式 (8) 的计算需要依赖 m N m_N mN,而 m N m_N mN 只有在第一个循环完成之后才能够确定

既然数据之间存在着依赖关系,那我们可以创建另外一个序列 d i ′ : = ∑ j = 1 i e x j − m i d_{i}^{ \prime} := \sum_{j=1}^{i}e^{x_{j}-m_{i}} di:=j=1iexjmi 作为原始序列 d i ′ : = ∑ j = 1 i e x j − m N d_{i}^{ \prime} := \sum_{j=1}^{i}e^{x_{j}-m_{N}} di:=j=1iexjmN 的替代,以消除对 N N N 的依赖,并且这两个序列的第 N N N 项是相同的即 d N = d N ′ d_N=d^{\prime}_N dN=dN,因此我们可以安全地用 d N ′ d^{\prime}_N dN 替换公式 (9) 中的 d N d_N dN

Note:在数学和计算机科学中,符号 : = := := 用于表示 “定义为” 或 “被定义为”,例如当我们写下 x : = y x:=y x:=y 时,我们是在说明 “我们将 x x x 定义为 y y y” 或 “ x x x 等于 y y y(这是它的定义)”

此外,我们还可以找到 d i ′ d^{\prime}_i di d i − 1 ′ d^{\prime}_{i-1} di1 之间的递归关系:

d i ′ = ∑ j = 1 i e x j − m i = ( ∑ j = 1 i − 1 e x j − m i ) + e x i − m i = ( ∑ j = 1 i − 1 e x j − m i − 1 ) e m i − 1 − m i + e x i − m i = d i − 1 ′ e m i − 1 − m i + e x i − m i \begin{equation} \begin{aligned} d_{i}^{\prime} & =\sum_{j=1}^ie^{x_j-m_i} \\ & =\left(\sum_{j=1}^{i-1}e^{x_j-m_i}\right)+e^{x_i-m_i} \\ & =\left(\sum_{j=1}^{i-1}e^{x_j-m_{i-1}}\right)e^{m_{i-1}-m_i}+e^{x_i-m_i} \\ & =d_{i-1}^{\prime}e^{m_{i-1}-m_i}+e^{x_i-m_i} \end{aligned} \end{equation} di=j=1iexjmi=(j=1i1exjmi)+eximi=(j=1i1exjmi1)emi1mi+eximi=di1emi1mi+eximi

这个递归公式只依赖于 m i m_i mi m i − 1 m_{i-1} mi1,此外我们还可以在一个循环中同时计算 m j m_j mj d j ′ d^{\prime}_j dj

Algorithm 2-pass online softmax

for i ← 1 , N i\leftarrow 1,N i1,N do

m i ← max ⁡ ( m i − 1 , x i ) d i ′ ← d i − 1 ′ e m i − 1 − m i + e x i − m i \begin{aligned} m_i & \leftarrow \max (m_{i-1}, x_i) \\ d^{\prime}_i & \leftarrow d^{\prime}_{i-1}e^{m_{i-1}-m_i}+e^{x_i-m_i} \end{aligned} midimax(mi1,xi)di1emi1mi+eximi

end

for i ← 1 , N i\leftarrow 1,N i1,N do

a i ← e x i − m N d N ′ a_i \leftarrow \frac{e^{x_i - m_N}}{d^{\prime}_N} aidNeximN

end

这是 Online Softmax 论文中提出的算法,但它仍然需要两次循环才能完成 softmax 的计算,我们能否将循环次数减少到 1 次来最小化全局的 I/O 呢?🤔

4. FlashAttention

不幸的是,对于 softmax 来说,答案是否定的,但在 Self-Attention 中,我们最终的目标并不是注意力分数矩阵 A = s o f t m a x ( Q K T ) A=\mathrm{softmax}(QK^T) A=softmax(QKT),而是输出矩阵 O = A × V O=A \times V O=A×V,既然无法找到 softmax 的一次递归形式,那我们换个思路思考下能否找到矩阵 O O O 的一次递归形式呢?

下面我们尝试下先将 Self-Attention 计算的第 k k k 行公式转化为递归算法:

Note:由于所有行的计算都是独立的,为了简单起见,这里我们只解释一行的计算

Algorithm Multi-pass Self-Attention

NOTATIONS

  • Q [ k , : ] Q[k,:] Q[k,:]:查询矩阵 Q Q Q 的第 k k k 行向量
  • K T [ : , i ] K^T[:,i] KT[:,i]:键矩阵 K T K^T KT 的第 i i i 列向量
  • O [ k , : ] O[k,:] O[k,:]:输出矩阵 O O O 的第 k k k
  • V [ i , : ] V[i,:] V[i,:]:值矩阵 V V V 的第 i i i
  • o i \bm{o}_i oi ∑ j = 1 i a j V [ j , : ] \sum_{j=1}^ia_jV[j,:] j=1iajV[j,:],存储 A [ k , : i ] × V [ : i , : ] A[k,:i]\times V[:i,:] A[k,:i]×V[:i,:] 结果的行向量

BODY

for i ← 1 , N i\leftarrow 1,N i1,N do

x i ← Q [ k , : ] K T [ : , i ] m i ← max ⁡ ( m i − 1 , x i ) d i ′ ← d i − 1 ′ e m i − 1 − m i + e x i − m i \begin{aligned} x_i & \leftarrow Q[k,:]K^T[:,i] \\ m_i & \leftarrow \max (m_{i-1},x_i) \\ d^{\prime}_i & \leftarrow d^{\prime}_{i-1} e^{m_{i-1}-m_i} + e^{x_i-m_i} \end{aligned} ximidiQ[k,:]KT[:,i]max(mi1,xi)di1emi1mi+eximi

end

for i ← 1 , N i\leftarrow 1,N i1,N do

a i ← e x i − m N d N ′ o i ← o i − 1 + a i V [ i , : ] \begin{align} a_i & \leftarrow \frac{e^{x_i-m_N}}{d^{\prime}_N} \\ \bm{o}_i & \leftarrow \bm{o}_{i-1} + a_iV[i,:] \end{align} aioidNeximNoi1+aiV[i,:]

end

O [ k , : ] ← o N O[k,:] \leftarrow \bm{o}_N O[k,:]oN

我们将公式 (12) 中的 a i a_i ai 替换公式 (11) 中的定义,有:

o i : = ∑ j = 1 i ( e x j − m N d N ′ V [ j , : ] ) \begin{align} \bm{o}_i := \sum_{j=1}^i\left(\frac{e^{x_j-m_N}}{d_N^{\prime}}V[j,:]\right) \end{align} oi:=j=1i(dNexjmNV[j,:])

该公式仍然依赖 m N m_N mN d N d_N dN 变量,这两个值在前一个循环完成之前无法确定。但我们可以再次使用第 3 节中的替代技巧,即创建替代序列 o ′ \bm{o}^{\prime} o

o i ′ : = ( ∑ j = 1 i e x j − m i d i ′ V [ j , : ] ) \bm{o}_i^{\prime}:=\left(\sum_{j=1}^i\frac{e^{x_j-m_i}}{d_i^{\prime}}V[j,:]\right) oi:=(j=1idiexjmiV[j,:])

o \bm{o} o o ′ \bm{o}^{\prime} o 的第 N N N 个元素相同即 o N ′ = o N \bm{o}^{\prime}_N=\bm{o}_N oN=oN ,并且我们可以发现 o i ′ \bm{o}^{\prime}_i oi o i − 1 ′ \bm{o}^{\prime}_{i-1} oi1 之间存在如下的递归关系:

o i ′ = ∑ j = 1 i e x j − m i d i ′ V [ j , : ] = ( ∑ j = 1 i − 1 e x j − m i d i ′ V [ j , : ] ) + e x i − m i d i ′ V [ i , : ] = ( ∑ j = 1 i − 1 e x j − m i − 1 d i − 1 ′ e x j − m i e x j − m i − 1 d i − 1 ′ d i ′ V [ j , : ] ) + e x i − m i d i ′ V [ i , : ] = ( ∑ j = 1 i − 1 e x j − m i − 1 d i − 1 ′ V [ j , : ] ) d i − 1 ′ d i ′ e m i − 1 − m i + e x i − m i d i ′ V [ i , : ] = o i − 1 ′ d i − 1 ′ e m i − 1 − m i d i ′ + e x i − m i d i ′ V [ i , : ] \begin{equation} \begin{aligned} \bm{o}^{\prime}_{i}& = \sum_{j=1}^{i} \frac{e^{x_{j}-m_{i}}}{d^{\prime}_{i}}V[j, : ]\\ & = \left(\sum_{j=1}^{i-1}\frac{e^{x_{j}-m_{i}}}{d^{\prime}_{i} }V[j, : ]\right) + \frac{e^{x_{i}-m_{i}}}{d^{\prime}_{i}}V[i, : ]\\ & = \left(\sum_{j=1}^{i-1}\frac{e^{x_{j}-m_{i-1}}}{d^{\prime}_{ i-1}}\frac{e^{x_{j}-m_{i}}}{e^{x_{j}-m_{i-1}}}\frac{d^{\prime}_{i-1}}{d^{ \prime}_{i}}V[j, : ]\right) + \frac{e^{x_{i}-m_{i}}}{d^{\prime}_{i}}V[i, : ]\\ & = \left(\sum_{j=1}^{i-1}\frac{e^{x_{j}-m_{i-1}}}{d^{\prime}_{ i-1}}V[j, : ]\right) \frac{d^{\prime}_{i-1}}{d^{\prime}_{i}}e^{m_{i-1}-m_{i}} + \frac{e^{x_{i}-m_{i}}}{d^{\prime}_{i}}V[i, : ]\\ & = \bm{o}^{\prime}_{i-1}\frac{d^{\prime}_{i-1} e^{m_{i-1}-m_{ i}}}{d^{\prime}_{i}} + \frac{e^{x_{i}-m_{i}}}{d^{\prime}_{i}}V[i, : ] \end{aligned} \end{equation} oi=j=1idiexjmiV[j,:]=(j=1i1diexjmiV[j,:])+dieximiV[i,:]=(j=1i1di1exjmi1exjmi1exjmididi1V[j,:])+dieximiV[i,:]=(j=1i1di1exjmi1V[j,:])didi1emi1mi+dieximiV[i,:]=oi1didi1emi1mi+dieximiV[i,:]

它仅依赖于 d i ′ , d i − 1 ′ , m i , m i − 1 , x i d^{\prime}_i,\ d^{\prime}_{i-1}, \ m_i, \ m_{i-1}, \ x_i di, di1, mi, mi1, xi ,因此我们可以在单个循环中融合 Self-Attention 中的所有计算

Algorithm FlashAttention

for i ← 1 , N i \leftarrow 1,N i1,N do

x i ← Q [ k , : ] K T [ : , i ] m i ← max ⁡ ( m i − 1 , x i ) d i ′ ← d i − 1 ′ e m i − 1 − m i + e x i − m i o i ′ ← o i − 1 ′ d i − 1 ′ e m i − 1 − m i d i ′ + e x i − m i d i ′ V [ i , : ] \begin{aligned} {x_{i}} & \leftarrow {Q[k,:]} {K^{T}[:,i]}\\ {m_{i}} & \leftarrow \max\left( {m_{i-1},x_{i}} \right)\\ {d^{\prime}_{i}} & \leftarrow {d^{\prime}_{i-1}e^{m_{i-1}-m_{i}}}+ {e^{x_{i}-m_{i}}}\\ {o^{\prime}_{i}} & \leftarrow {o^{\prime}_{i-1}}\frac{ {d^{\prime}_{i-1}e^{m_{i-1}-m_{i}}}}{ { d^{\prime}_{i}}}+\frac{ {e^{x_{i}-m_{i}}}}{ {d^{\prime}_{i}}} {V[i,:]}\end{aligned} ximidioiQ[k,:]KT[:,i]max(mi1,xi)di1emi1mi+eximioi1didi1emi1mi+dieximiV[i,:]

end

O [ k , : ] ← o N ′ O[k,:] \leftarrow \bm{o}^{\prime}_N O[k,:]oN

状态量 x i , m i , d i ′ , o i ′ x_i, \ m_i, \ d^{\prime}_i, \bm{o}^{\prime}_i xi, mi, di,oi 占用的内存都很小,可以非常轻松的放入 GPU 的 shared memory 中。另外由于此算法中的所有操作都是关联的,因此它与 tiling 兼容,如果我们逐个 tiling 计算,则该算法可以表示为:

Algorithm FlashAttention(Tiling)

NEW NOTATIONS

  • b b b:tile 的 block size 大小
  • #tiles \text{\#tiles} #tiles:每行 tile 的数量, N = b × #tiles N= b \times \text{\#tiles} N=b×#tiles
  • x i \bm{x}_i xi:存储第 i i i 个 tile 的 Q [ k ] K T Q[k]K^T Q[k]KT 值的向量 [ ( i − 1 ) b : i b ] [(i-1)b:ib] [(i1)b:ib]
  • m i ( l o c a l ) m_i^{\mathrm{(local)}} mi(local) x i \bm{x}_i xi 内部的局部最大值

BODY

for i ← 1 , #tile i \leftarrow 1,\text{\#tile} i1,#tile do

x i ← Q [ k , : ] K T [ : , ( i − 1 ) b : i b ] m i (local) = max ⁡ j = 1 b ( x i [ j ] ) m i ← max ⁡ ( m i − 1 , m i (local) ) d i ′ ← d i − 1 ′ e m i − 1 − m i + ∑ j = 1 b e x i [ j ] − m i o i ′ ← o i − 1 ′ d i − 1 ′ e m i − 1 − m i d i ′ + ∑ j = 1 b e x i [ j ] − m i d i ′ V [ j + ( i − 1 ) b , : ] \begin{split}\bm{x}_{i}&\ \leftarrow\ Q[k, : ] K^{T}[:,(i-1) b : i b]\\ m_{i}^{\text{(local)}}&\ =\ \max_{j=1}^{b} (\bm{x}_{i} [j]) \\ m_{i}&\ \leftarrow\ \max\big(m_{i-1},m_{i}^{\text{(local)}} \big)\\ d_{i}^{\prime}&\ \leftarrow\ d_{i-1}^{\prime} e^{m_{i-1}-m_{i}} + \sum_{j=1}^{b} e^{\bm{x}_{i}[j]-m_{i}}\\ \bm{o}_{i}^{\prime}&\ \leftarrow\ \bm{o}_{i-1}^{\prime} \frac{d_{i-1}^{\prime} e^{m_{i-1}-m_{i}}}{d_{i}^{\prime}} + \sum_{j=1}^{b} \frac{e^{\bm{x}_{i}[j]-m_{i}}}{d_{i}^{\prime}}V[j+(i-1) b, : ]\end{split} ximi(local)midioi  Q[k,:]KT[:,(i1)b:ib] = j=1maxb(xi[j])  max(mi1,mi(local))  di1emi1mi+j=1bexi[j]mi  oi1didi1emi1mi+j=1bdiexi[j]miV[j+(i1)b,:]

end

O [ k , : ] ← o N / b ′ O[k,:] \leftarrow \bm{o}^{\prime}_{N/b} O[k,:]oN/b

下图说明了如何将该算法应用到硬件上:

在这里插入图片描述

上图说明了 FlashAttention 在硬件上的计算方式,其中蓝色块表示在 SRAM 中的 tile,而红色块对应于 tile 中的第 i i i 行, L L L 表示序列长度, D D D 表示每个头的维度, B B B 表示 block size 的大小

值得注意的是,整体 SRAM 的内存占用仅仅取决于 B B B D D D,与 L L L 无关,因此,该算法可以扩展到长上下文而不会遇到内存问题(GPU 的共享内存很小,例如 H100 架构中为 228kb/SM)。在计算过程中,我们从左到右遍历 K T K^T KT A A A,从上到下遍历 V V V,并相应地更新 m , d , O m,\ d, \ O m, d, O 的状态

结语

这篇文章主要分享了 FlashAttention 的思想,FlashAttention 想要做的就是在单个循环中完成 self-attention 的整个计算

我们先从 safe softmax 出发分析了 safe softmax 算法需要循环迭代 3 次,第一次循环求取序列中的最大值 m N m_N mN,第二次循环求取 safe softmax 的分母 d N d_N dN,第三次循环求取最终的 softmax a i a_i ai。由于 SRAM 片上内存空间有限,我们无法将所有的 x i x_i xi 缓存,导致我们在每次迭代时需要重新访问 Q , K Q,K Q,K 以动态重新计算 x i x_i xi,这造成了大量的 I/O 操作,我们并不希望看到

在 online softmax 中提出了一个有意思的想法,那就是创建一个新的序列 d i ′ d_i^{\prime} di 来替换原始的 d i d_i di,主要区别在于将原来 d i d_i di 定义中的 m N m_N mN 替换为了 m i m_i mi,以消除对 N N N 的依赖,并且我们发现 d N = d N ′ d_N=d_N^{\prime} dN=dN,那么整个 safe softmax 的计算就只需要迭代 2 次即可,第一次循环迭代可以同时计算 m i m_i mi d i ′ d_i^{\prime} di

FlashAttention 想要把 self-attention 中的循环次数降低到一次,但是发现对于其中的 softmax 而言无法实现,因此它另辟蹊径试图寻找输出矩阵 O O O 的一次迭代形式。经过推导我们发现输出行向量 o i \bm{o}_i oi 仍然依赖于 m N m_N mN d N d_N dN 变量,这两个值在前一个循环完成之前无法确定,因此仍需要两次循环

FlashAttention 借助 online softmax 的替代思想,将序列 o \bm{o} o 替换为 o ′ \bm{o}^{\prime} o,替换之后发现 o i ′ \bm{o}_i^{\prime} oi 不再依赖于 m N m_N mN d N d_N dN 等变量,而仅依赖于 d i ′ , d i − 1 ′ , m i , m i − 1 , x i d^{\prime}_i,\ d^{\prime}_{i-1}, \ m_i, \ m_{i-1}, \ x_i di, di1, mi, mi1, xi,从而可以在单个循环中融合 self-attention 中的所有计算

最后,由于 flashattention 算法中的所有操作都是关联的,因此可以逐个 tiling 分块计算,并且整体 SRAM 的内存占用与序列长度无关,因此可以扩展到长上下文而不用担心遇到内存问题

OK,以上就是关于这篇手稿的全部内容了,大家感兴趣的可以看看,作为 FlashAttention 的入门还是没有问题的😄

参考

  • FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness
  • Gtc 2020: developing cuda kernels to push tensor cores to the absolute limit on nvidia a100
  • Online normalizer calculation for softmax

http://www.ppmy.cn/embedded/173253.html

相关文章

golang开发支持onlyoffice的token功能

一直都没去弄token这块,想着反正docker run的时候将jwt置为false即可。 看了好多文章,感觉可以试试,但是所有文件几乎都没说思路。 根据我的理解和成功的调试,思路是: 我们先定义2个概念,一个是文档下载…

电子招采软件系统,如何实现10年可追溯审计

一、在当前经济环境下,中小企业面临着巨大的生存压力,传统产业的数字化转型迫在眉睫。AI技术为企业的低成本高效发展提供了新机会,混合办公成为新常态,数据安全法的深入落实则进一步推动企业重视数据安全。区块链存证技术凭借独特…

Mac下安装Zed以及Zed对MCP(模型上下文协议)的支持

Zed是当前新流行的一种编辑器,支持MCP(模型上下文协议) Mac下安装Zed比较简单,直接有安装包,在这里: brew install --cask zedMac Monterey下是可以安装上的,亲测有效。 配置 使用CtrlShiftP…

游戏引擎学习第157天

今天的计划 目标是完整制作一款游戏,从头到尾的开发过程完全展示。过程中没有使用任何游戏引擎或库,目的是展示一个全面的游戏开发过程,包括每一个细节,从最基础的像素开始,直到最终的视觉效果。在整个过程中&#xf…

深入理解 Reactor Netty 线程配置及启动命令设置

一、引言 在使用 Spring Boot 开发基于 Reactor Netty 的应用程序时,合理配置 Reactor Netty 的线程参数对于优化应用性能至关重要。本文将详细介绍 reactor.netty.ioSelectCount 和 reactor.netty.ioWorkerCount 这两个关键参数的作用、不同设置值的影响&#xff0…

基于stm32的视觉物流机器人

标题:基于stm32的视觉物流机器人 内容:1.摘要 本文围绕基于STM32的视觉物流机器人展开研究。背景是随着物流行业的快速发展,对物流自动化和智能化的需求日益增长,传统物流方式效率低且成本高。目的是设计一款基于STM32的视觉物流机器人,以提…

prompt工程起步

1.手工提示词 有关CLIP和ActionClip的手工特征,也是一个进步。通过给标签填入不同的修饰语当中,组成一段话来,来增强语义理解 def text_prompt(data):text_aug [f"a photo of action {{}}", f"a picture of action {{}}", f"Human acti…

QT:非模态使用WA_DeleteOnClose避免内存泄漏

connect(ui->actionnewFile,&QAction::triggered,this,[](){QDialog*dlg new QDialog(this);//dlg.exec();dlg->show();dlg->setAttribute(Qt::WA_DeleteOnClose);qDebug()<<"打开对话框";}); 1. QDialog* dlg new QDialog(this); - 创建了…