目录
- 前言
- 0. Abstract
- 1. The Self-Attention
- 2. (Safe) Softmax
- 3. Online Softmax
- 4. FlashAttention
- 结语
- 参考
前言
最近在学习 FlashAttention,看到一份不错的手稿分享下🤗
manuscript:From Online Softmax to FlashAttention
0. Abstract
FlashAttention 的关键创新是使用类似于 Online Softmax 的思想来 tile 分块 self-attention 的计算,这样可以融合整个多头注意力层,而无需多次重复访问 GPU 的全局内存来获取临时变量和注意力分数矩阵 A A A。本文将简要解释如何 tiling 分块 self-attention 的计算,以及如何从 online softmax 中推导出 FlashAttention 计算
1. The Self-Attention
Self-Attention 的计算可以描述为:
O = s o f t m a x ( Q K T ) V \begin{equation} O = \mathrm{softmax}\left(QK^T\right)V \end{equation} O=softmax(QKT)V
其中 Q , K , V , O ∈ R L × D Q,K,V,O \in \mathbb{R}^{L\times D} Q,K,V,O∈RL×D, L L L 是序列长度, D D D 是每个头的维度,softmax 将按列应用于 Q K T QK^T QKT
Note:这里我们忽略了多头、多 batch,因为在这些维度上的计算是完全并行的,也就是说我们只关注单头、单 batch 的计算就行。另外为了简单起见,我们也忽略了注意力掩码以及缩放因子 1 D \frac{1}{\sqrt{D}} D1
计算 self-attention 的标准方法是将其分解为以下几个阶段:
X = Q K T A = s o f t m a x ( X ) O = A V \begin{align} X & = QK^T \\ A & = \mathrm{softmax}(X) \\ O & = AV \end{align} XAO=QKT=softmax(X)=AV
我们称 X X X 矩阵为 pre-softmax logits, A A A 矩阵为注意力分数(attention score), O O O 矩阵为输出
FlashAttention 的一个惊人之处在于,我们不需要在全局内存(global memory)上实现 X X X 和 A A A 矩阵,而是将公式 (1) 中的整个计算融合到单个 CUDA kernel 中。这要求我们设计一种能够精心管理片上内存(on-chip memory)的算法,因为 NVIDIA GPU 的共享内存(shared memory)非常小
对于矩阵乘法等经典算法,我们通常会使用 tiling 技术来确保片上内存不超过硬件限制。下图提供了一个例子,在 kernel 执行期间,无论矩阵形状如何,只有 3 T 2 3T^2 3T2 个元素存储在片上。这种 tiling 方法之所以有效,是因为加法是关联的,允许将整个矩阵乘法分解为许多 tile-wise 矩阵乘法的总和
然而,Self-Attention 包含一个不直接关联的 softmax 运算符,因此很难像下图那样简单地对 Self-Attention 进行 tile,那有没有办法让 softmax 具有关联性呢?🤔
上图简要说明了如何对矩阵乘法 C = A × B C=A\times B C=A×B 的输入和输出矩阵进行 tile(分块),矩阵被划分为 T × T T\times T T×T 个小块。对于每个输出块,我们从左到右遍历 A A A 中的相关块,从上到下遍历 B B B 中的相关块,并将它们的值从全局内存(global memory)加载到片上内存(on-chip memory)(以蓝色表示,整体片上内存占用为 O ( T 2 ) O(T^2) O(T2))
接着逐块进行矩阵乘法,对于位置 ( i , j ) (i,j) (i,j),我们先从片上内存中加载块内所有 i i i 行的元素即 A [ i , k ] A[i,k] A[i,k] 以及所有 j j j 列的元素即 B [ k , j ] B[k,j] B[k,j](以红色表示),然后利用 A [ i , k ] A[i,k] A[i,k] 和 B [ k , j ] B[k,j] B[k,j] 计算得到片上内存中的 C [ i , j ] C[i,j] C[i,j]。以此循环,直到完成一个块的计算后,我们再将片上 C C C 块的结果写回主内存并继续处理下一个块
实际 tiling 的应用要复杂得多,大家可以参考:cutlass 在 A100 上实现矩阵乘法
2. (Safe) Softmax
我们先回顾一下 softmax 算子,下面是 softmax 计算的通用公式:
s o f t m a x ( { x 1 , … , x N } ) = { e x i ∑ j = 1 N e x j } i = 1 N \begin{equation} \mathrm{softmax}(\{x_1,\ldots,x_N\})=\left\{\frac{e^{x_i}}{\sum_{j=1}^Ne^{x_j}}\right\}_{i=1}^N \end{equation} softmax({x1,…,xN})={∑j=1Nexjexi}i=1N
值得注意的是,当 x i x_i xi 比较大时 e x i e^{x_i} exi 很容易溢出,例如 float16 可以支持的最大值是 65536,这意味着当 x ⩾ 11 x \geqslant 11 x⩾11 时, e x e^x ex 将超过 float16 的有效范围
为了解决这个问题,像 pytorch、tensorflow 等框架通常会使用一种被称为 “safe” 的 softmax 计算方法:
e x i ∑ j = 1 N e x j = e x i − m ∑ j = 1 N e x j − m \begin{equation} \frac{e^{x_{i}}}{\sum_{j=1}^{N}e^{x_{j}}}=\frac{e^{x_{i}-m}}{\sum_{j=1}^{N}e^{x_{j}-m}} \end{equation} ∑j=1Nexjexi=∑j=1Nexj−mexi−m
其中 m = max j = 1 N ( x j ) m=\max_{j=1}^{N}(x_j) m=maxj=1N(xj)
我们将 x i x_i xi 减去它们中的最大值,这样可以确保每个元素 x i − m ⩽ 0 x_i-m \leqslant 0 xi−m⩽0,从而再做指数运算时可以确保其值不会溢出
我们可以将 safe softmax 的计算总结为下面的 3-pass 算法:
Algorithm 3-pass safe softmax
NOTATIONS
- { m i } \{m_i\} {mi}: max j = 1 i { x j } \max_{j=1}^i\{x_j\} maxj=1i{xj},初始值 m 0 = − ∞ m_0=-\infty m0=−∞
- { d i } \{d_i\} {di}: ∑ j = 1 i e x j − m N \sum_{j=1}^{i}e^{x_{j}-m_{N}} ∑j=1iexj−mN,初始值 d 0 = 0 d_0=0 d0=0, d N d_N dN 是 safe softmax 的分母
- { a i } \{a_i\} {ai}:最终的 softmax 值
BODY
for i ← 1 , N i\leftarrow1,N i←1,N do
m i ← max ( m i − 1 , x i ) \begin{equation} m_i \leftarrow \max (m_{i-1},x_i) \end{equation} mi←max(mi−1,xi)
end
for i ← 1 , N i\leftarrow1,N i←1,N do
d i ← d i − 1 + e x i − m N \begin{equation} d_i \leftarrow d_{i-1} + e^{x_i-m_N} \end{equation} di←di−1+exi−mN
end
for i ← 1 , N i\leftarrow1,N i←1,N do
a i ← e x i − m N d N \begin{equation} a_i \leftarrow \frac{e^{x_i-m_N}}{d_N} \end{equation} ai←dNexi−mN
end
该算法要求我们对 [ 1 , N ] [1,N] [1,N] 进行 3 次迭代,在 Transformer 的 self-attention 中,每个 x i x_i xi 是通过 Q K T QK^T QKT 计算得到的。如果我们无法将所有的 x i x_i xi 都缓存到片上内存(SRAM)中(实际上我们也没有足够大的片上内存 SRAM 来容纳所有的 x i x_i xi),那么在每次迭代时不得不访问 Q , K Q,K Q,K 以动态重新计算 x i x_i xi,这种频繁的访问会导致大量的 I/O 操作,从而降低整体效率
3. Online Softmax
如果我们能够在一个循环中融合公式 (7)、(8)、(9),那么就可以将 global memory 的访问时间降低 3 倍,不幸的是,我们不能在同一个循环中融合公式 (7)、(8),因为公式 (8) 的计算需要依赖 m N m_N mN,而 m N m_N mN 只有在第一个循环完成之后才能够确定
既然数据之间存在着依赖关系,那我们可以创建另外一个序列 d i ′ : = ∑ j = 1 i e x j − m i d_{i}^{ \prime} := \sum_{j=1}^{i}e^{x_{j}-m_{i}} di′:=∑j=1iexj−mi 作为原始序列 d i ′ : = ∑ j = 1 i e x j − m N d_{i}^{ \prime} := \sum_{j=1}^{i}e^{x_{j}-m_{N}} di′:=∑j=1iexj−mN 的替代,以消除对 N N N 的依赖,并且这两个序列的第 N N N 项是相同的即 d N = d N ′ d_N=d^{\prime}_N dN=dN′,因此我们可以安全地用 d N ′ d^{\prime}_N dN′ 替换公式 (9) 中的 d N d_N dN
Note:在数学和计算机科学中,符号 : = := := 用于表示 “定义为” 或 “被定义为”,例如当我们写下 x : = y x:=y x:=y 时,我们是在说明 “我们将 x x x 定义为 y y y” 或 “ x x x 等于 y y y(这是它的定义)”
此外,我们还可以找到 d i ′ d^{\prime}_i di′ 和 d i − 1 ′ d^{\prime}_{i-1} di−1′ 之间的递归关系:
d i ′ = ∑ j = 1 i e x j − m i = ( ∑ j = 1 i − 1 e x j − m i ) + e x i − m i = ( ∑ j = 1 i − 1 e x j − m i − 1 ) e m i − 1 − m i + e x i − m i = d i − 1 ′ e m i − 1 − m i + e x i − m i \begin{equation} \begin{aligned} d_{i}^{\prime} & =\sum_{j=1}^ie^{x_j-m_i} \\ & =\left(\sum_{j=1}^{i-1}e^{x_j-m_i}\right)+e^{x_i-m_i} \\ & =\left(\sum_{j=1}^{i-1}e^{x_j-m_{i-1}}\right)e^{m_{i-1}-m_i}+e^{x_i-m_i} \\ & =d_{i-1}^{\prime}e^{m_{i-1}-m_i}+e^{x_i-m_i} \end{aligned} \end{equation} di′=j=1∑iexj−mi=(j=1∑i−1exj−mi)+exi−mi=(j=1∑i−1exj−mi−1)emi−1−mi+exi−mi=di−1′emi−1−mi+exi−mi
这个递归公式只依赖于 m i m_i mi 和 m i − 1 m_{i-1} mi−1,此外我们还可以在一个循环中同时计算 m j m_j mj 和 d j ′ d^{\prime}_j dj′
Algorithm 2-pass online softmax
for i ← 1 , N i\leftarrow 1,N i←1,N do
m i ← max ( m i − 1 , x i ) d i ′ ← d i − 1 ′ e m i − 1 − m i + e x i − m i \begin{aligned} m_i & \leftarrow \max (m_{i-1}, x_i) \\ d^{\prime}_i & \leftarrow d^{\prime}_{i-1}e^{m_{i-1}-m_i}+e^{x_i-m_i} \end{aligned} midi′←max(mi−1,xi)←di−1′emi−1−mi+exi−mi
end
for i ← 1 , N i\leftarrow 1,N i←1,N do
a i ← e x i − m N d N ′ a_i \leftarrow \frac{e^{x_i - m_N}}{d^{\prime}_N} ai←dN′exi−mN
end
这是 Online Softmax 论文中提出的算法,但它仍然需要两次循环才能完成 softmax 的计算,我们能否将循环次数减少到 1 次来最小化全局的 I/O 呢?🤔
4. FlashAttention
不幸的是,对于 softmax 来说,答案是否定的,但在 Self-Attention 中,我们最终的目标并不是注意力分数矩阵 A = s o f t m a x ( Q K T ) A=\mathrm{softmax}(QK^T) A=softmax(QKT),而是输出矩阵 O = A × V O=A \times V O=A×V,既然无法找到 softmax 的一次递归形式,那我们换个思路思考下能否找到矩阵 O O O 的一次递归形式呢?
下面我们尝试下先将 Self-Attention 计算的第 k k k 行公式转化为递归算法:
Note:由于所有行的计算都是独立的,为了简单起见,这里我们只解释一行的计算
Algorithm Multi-pass Self-Attention
NOTATIONS
- Q [ k , : ] Q[k,:] Q[k,:]:查询矩阵 Q Q Q 的第 k k k 行向量
- K T [ : , i ] K^T[:,i] KT[:,i]:键矩阵 K T K^T KT 的第 i i i 列向量
- O [ k , : ] O[k,:] O[k,:]:输出矩阵 O O O 的第 k k k 行
- V [ i , : ] V[i,:] V[i,:]:值矩阵 V V V 的第 i i i 行
- o i \bm{o}_i oi: ∑ j = 1 i a j V [ j , : ] \sum_{j=1}^ia_jV[j,:] ∑j=1iajV[j,:],存储 A [ k , : i ] × V [ : i , : ] A[k,:i]\times V[:i,:] A[k,:i]×V[:i,:] 结果的行向量
BODY
for i ← 1 , N i\leftarrow 1,N i←1,N do
x i ← Q [ k , : ] K T [ : , i ] m i ← max ( m i − 1 , x i ) d i ′ ← d i − 1 ′ e m i − 1 − m i + e x i − m i \begin{aligned} x_i & \leftarrow Q[k,:]K^T[:,i] \\ m_i & \leftarrow \max (m_{i-1},x_i) \\ d^{\prime}_i & \leftarrow d^{\prime}_{i-1} e^{m_{i-1}-m_i} + e^{x_i-m_i} \end{aligned} ximidi′←Q[k,:]KT[:,i]←max(mi−1,xi)←di−1′emi−1−mi+exi−mi
end
for i ← 1 , N i\leftarrow 1,N i←1,N do
a i ← e x i − m N d N ′ o i ← o i − 1 + a i V [ i , : ] \begin{align} a_i & \leftarrow \frac{e^{x_i-m_N}}{d^{\prime}_N} \\ \bm{o}_i & \leftarrow \bm{o}_{i-1} + a_iV[i,:] \end{align} aioi←dN′exi−mN←oi−1+aiV[i,:]
end
O [ k , : ] ← o N O[k,:] \leftarrow \bm{o}_N O[k,:]←oN
我们将公式 (12) 中的 a i a_i ai 替换公式 (11) 中的定义,有:
o i : = ∑ j = 1 i ( e x j − m N d N ′ V [ j , : ] ) \begin{align} \bm{o}_i := \sum_{j=1}^i\left(\frac{e^{x_j-m_N}}{d_N^{\prime}}V[j,:]\right) \end{align} oi:=j=1∑i(dN′exj−mNV[j,:])
该公式仍然依赖 m N m_N mN 和 d N d_N dN 变量,这两个值在前一个循环完成之前无法确定。但我们可以再次使用第 3 节中的替代技巧,即创建替代序列 o ′ \bm{o}^{\prime} o′:
o i ′ : = ( ∑ j = 1 i e x j − m i d i ′ V [ j , : ] ) \bm{o}_i^{\prime}:=\left(\sum_{j=1}^i\frac{e^{x_j-m_i}}{d_i^{\prime}}V[j,:]\right) oi′:=(j=1∑idi′exj−miV[j,:])
o \bm{o} o 和 o ′ \bm{o}^{\prime} o′ 的第 N N N 个元素相同即 o N ′ = o N \bm{o}^{\prime}_N=\bm{o}_N oN′=oN ,并且我们可以发现 o i ′ \bm{o}^{\prime}_i oi′ 和 o i − 1 ′ \bm{o}^{\prime}_{i-1} oi−1′ 之间存在如下的递归关系:
o i ′ = ∑ j = 1 i e x j − m i d i ′ V [ j , : ] = ( ∑ j = 1 i − 1 e x j − m i d i ′ V [ j , : ] ) + e x i − m i d i ′ V [ i , : ] = ( ∑ j = 1 i − 1 e x j − m i − 1 d i − 1 ′ e x j − m i e x j − m i − 1 d i − 1 ′ d i ′ V [ j , : ] ) + e x i − m i d i ′ V [ i , : ] = ( ∑ j = 1 i − 1 e x j − m i − 1 d i − 1 ′ V [ j , : ] ) d i − 1 ′ d i ′ e m i − 1 − m i + e x i − m i d i ′ V [ i , : ] = o i − 1 ′ d i − 1 ′ e m i − 1 − m i d i ′ + e x i − m i d i ′ V [ i , : ] \begin{equation} \begin{aligned} \bm{o}^{\prime}_{i}& = \sum_{j=1}^{i} \frac{e^{x_{j}-m_{i}}}{d^{\prime}_{i}}V[j, : ]\\ & = \left(\sum_{j=1}^{i-1}\frac{e^{x_{j}-m_{i}}}{d^{\prime}_{i} }V[j, : ]\right) + \frac{e^{x_{i}-m_{i}}}{d^{\prime}_{i}}V[i, : ]\\ & = \left(\sum_{j=1}^{i-1}\frac{e^{x_{j}-m_{i-1}}}{d^{\prime}_{ i-1}}\frac{e^{x_{j}-m_{i}}}{e^{x_{j}-m_{i-1}}}\frac{d^{\prime}_{i-1}}{d^{ \prime}_{i}}V[j, : ]\right) + \frac{e^{x_{i}-m_{i}}}{d^{\prime}_{i}}V[i, : ]\\ & = \left(\sum_{j=1}^{i-1}\frac{e^{x_{j}-m_{i-1}}}{d^{\prime}_{ i-1}}V[j, : ]\right) \frac{d^{\prime}_{i-1}}{d^{\prime}_{i}}e^{m_{i-1}-m_{i}} + \frac{e^{x_{i}-m_{i}}}{d^{\prime}_{i}}V[i, : ]\\ & = \bm{o}^{\prime}_{i-1}\frac{d^{\prime}_{i-1} e^{m_{i-1}-m_{ i}}}{d^{\prime}_{i}} + \frac{e^{x_{i}-m_{i}}}{d^{\prime}_{i}}V[i, : ] \end{aligned} \end{equation} oi′=j=1∑idi′exj−miV[j,:]=(j=1∑i−1di′exj−miV[j,:])+di′exi−miV[i,:]=(j=1∑i−1di−1′exj−mi−1exj−mi−1exj−midi′di−1′V[j,:])+di′exi−miV[i,:]=(j=1∑i−1di−1′exj−mi−1V[j,:])di′di−1′emi−1−mi+di′exi−miV[i,:]=oi−1′di′di−1′emi−1−mi+di′exi−miV[i,:]
它仅依赖于 d i ′ , d i − 1 ′ , m i , m i − 1 , x i d^{\prime}_i,\ d^{\prime}_{i-1}, \ m_i, \ m_{i-1}, \ x_i di′, di−1′, mi, mi−1, xi ,因此我们可以在单个循环中融合 Self-Attention 中的所有计算
Algorithm FlashAttention
for i ← 1 , N i \leftarrow 1,N i←1,N do
x i ← Q [ k , : ] K T [ : , i ] m i ← max ( m i − 1 , x i ) d i ′ ← d i − 1 ′ e m i − 1 − m i + e x i − m i o i ′ ← o i − 1 ′ d i − 1 ′ e m i − 1 − m i d i ′ + e x i − m i d i ′ V [ i , : ] \begin{aligned} {x_{i}} & \leftarrow {Q[k,:]} {K^{T}[:,i]}\\ {m_{i}} & \leftarrow \max\left( {m_{i-1},x_{i}} \right)\\ {d^{\prime}_{i}} & \leftarrow {d^{\prime}_{i-1}e^{m_{i-1}-m_{i}}}+ {e^{x_{i}-m_{i}}}\\ {o^{\prime}_{i}} & \leftarrow {o^{\prime}_{i-1}}\frac{ {d^{\prime}_{i-1}e^{m_{i-1}-m_{i}}}}{ { d^{\prime}_{i}}}+\frac{ {e^{x_{i}-m_{i}}}}{ {d^{\prime}_{i}}} {V[i,:]}\end{aligned} ximidi′oi′←Q[k,:]KT[:,i]←max(mi−1,xi)←di−1′emi−1−mi+exi−mi←oi−1′di′di−1′emi−1−mi+di′exi−miV[i,:]
end
O [ k , : ] ← o N ′ O[k,:] \leftarrow \bm{o}^{\prime}_N O[k,:]←oN′
状态量 x i , m i , d i ′ , o i ′ x_i, \ m_i, \ d^{\prime}_i, \bm{o}^{\prime}_i xi, mi, di′,oi′ 占用的内存都很小,可以非常轻松的放入 GPU 的 shared memory 中。另外由于此算法中的所有操作都是关联的,因此它与 tiling 兼容,如果我们逐个 tiling 计算,则该算法可以表示为:
Algorithm FlashAttention(Tiling)
NEW NOTATIONS
- b b b:tile 的 block size 大小
- #tiles \text{\#tiles} #tiles:每行 tile 的数量, N = b × #tiles N= b \times \text{\#tiles} N=b×#tiles
- x i \bm{x}_i xi:存储第 i i i 个 tile 的 Q [ k ] K T Q[k]K^T Q[k]KT 值的向量 [ ( i − 1 ) b : i b ] [(i-1)b:ib] [(i−1)b:ib]
- m i ( l o c a l ) m_i^{\mathrm{(local)}} mi(local): x i \bm{x}_i xi 内部的局部最大值
BODY
for i ← 1 , #tile i \leftarrow 1,\text{\#tile} i←1,#tile do
x i ← Q [ k , : ] K T [ : , ( i − 1 ) b : i b ] m i (local) = max j = 1 b ( x i [ j ] ) m i ← max ( m i − 1 , m i (local) ) d i ′ ← d i − 1 ′ e m i − 1 − m i + ∑ j = 1 b e x i [ j ] − m i o i ′ ← o i − 1 ′ d i − 1 ′ e m i − 1 − m i d i ′ + ∑ j = 1 b e x i [ j ] − m i d i ′ V [ j + ( i − 1 ) b , : ] \begin{split}\bm{x}_{i}&\ \leftarrow\ Q[k, : ] K^{T}[:,(i-1) b : i b]\\ m_{i}^{\text{(local)}}&\ =\ \max_{j=1}^{b} (\bm{x}_{i} [j]) \\ m_{i}&\ \leftarrow\ \max\big(m_{i-1},m_{i}^{\text{(local)}} \big)\\ d_{i}^{\prime}&\ \leftarrow\ d_{i-1}^{\prime} e^{m_{i-1}-m_{i}} + \sum_{j=1}^{b} e^{\bm{x}_{i}[j]-m_{i}}\\ \bm{o}_{i}^{\prime}&\ \leftarrow\ \bm{o}_{i-1}^{\prime} \frac{d_{i-1}^{\prime} e^{m_{i-1}-m_{i}}}{d_{i}^{\prime}} + \sum_{j=1}^{b} \frac{e^{\bm{x}_{i}[j]-m_{i}}}{d_{i}^{\prime}}V[j+(i-1) b, : ]\end{split} ximi(local)midi′oi′ ← Q[k,:]KT[:,(i−1)b:ib] = j=1maxb(xi[j]) ← max(mi−1,mi(local)) ← di−1′emi−1−mi+j=1∑bexi[j]−mi ← oi−1′di′di−1′emi−1−mi+j=1∑bdi′exi[j]−miV[j+(i−1)b,:]
end
O [ k , : ] ← o N / b ′ O[k,:] \leftarrow \bm{o}^{\prime}_{N/b} O[k,:]←oN/b′
下图说明了如何将该算法应用到硬件上:
上图说明了 FlashAttention 在硬件上的计算方式,其中蓝色块表示在 SRAM 中的 tile,而红色块对应于 tile 中的第 i i i 行, L L L 表示序列长度, D D D 表示每个头的维度, B B B 表示 block size 的大小
值得注意的是,整体 SRAM 的内存占用仅仅取决于 B B B 和 D D D,与 L L L 无关,因此,该算法可以扩展到长上下文而不会遇到内存问题(GPU 的共享内存很小,例如 H100 架构中为 228kb/SM)。在计算过程中,我们从左到右遍历 K T K^T KT 和 A A A,从上到下遍历 V V V,并相应地更新 m , d , O m,\ d, \ O m, d, O 的状态
结语
这篇文章主要分享了 FlashAttention 的思想,FlashAttention 想要做的就是在单个循环中完成 self-attention 的整个计算
我们先从 safe softmax 出发分析了 safe softmax 算法需要循环迭代 3 次,第一次循环求取序列中的最大值 m N m_N mN,第二次循环求取 safe softmax 的分母 d N d_N dN,第三次循环求取最终的 softmax 值 a i a_i ai。由于 SRAM 片上内存空间有限,我们无法将所有的 x i x_i xi 缓存,导致我们在每次迭代时需要重新访问 Q , K Q,K Q,K 以动态重新计算 x i x_i xi,这造成了大量的 I/O 操作,我们并不希望看到
在 online softmax 中提出了一个有意思的想法,那就是创建一个新的序列 d i ′ d_i^{\prime} di′ 来替换原始的 d i d_i di,主要区别在于将原来 d i d_i di 定义中的 m N m_N mN 替换为了 m i m_i mi,以消除对 N N N 的依赖,并且我们发现 d N = d N ′ d_N=d_N^{\prime} dN=dN′,那么整个 safe softmax 的计算就只需要迭代 2 次即可,第一次循环迭代可以同时计算 m i m_i mi 和 d i ′ d_i^{\prime} di′
FlashAttention 想要把 self-attention 中的循环次数降低到一次,但是发现对于其中的 softmax 而言无法实现,因此它另辟蹊径试图寻找输出矩阵 O O O 的一次迭代形式。经过推导我们发现输出行向量 o i \bm{o}_i oi 仍然依赖于 m N m_N mN 和 d N d_N dN 变量,这两个值在前一个循环完成之前无法确定,因此仍需要两次循环
FlashAttention 借助 online softmax 的替代思想,将序列 o \bm{o} o 替换为 o ′ \bm{o}^{\prime} o′,替换之后发现 o i ′ \bm{o}_i^{\prime} oi′ 不再依赖于 m N m_N mN 和 d N d_N dN 等变量,而仅依赖于 d i ′ , d i − 1 ′ , m i , m i − 1 , x i d^{\prime}_i,\ d^{\prime}_{i-1}, \ m_i, \ m_{i-1}, \ x_i di′, di−1′, mi, mi−1, xi,从而可以在单个循环中融合 self-attention 中的所有计算
最后,由于 flashattention 算法中的所有操作都是关联的,因此可以逐个 tiling 分块计算,并且整体 SRAM 的内存占用与序列长度无关,因此可以扩展到长上下文而不用担心遇到内存问题
OK,以上就是关于这篇手稿的全部内容了,大家感兴趣的可以看看,作为 FlashAttention 的入门还是没有问题的😄
参考
- FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness
- Gtc 2020: developing cuda kernels to push tensor cores to the absolute limit on nvidia a100
- Online normalizer calculation for softmax