The Quadratic Complexity of Self-Attention in Transformers and How to Improve It
As the input sequence length of large language models (LLMs) grows, the computational cost and memory consumption of the Transformer's core module, the self-attention mechanism, grow quadratically. This not only limits the model's ability to process long sequences but also becomes a major bottleneck in both training and inference.
This post explains in detail where the quadratic complexity of self-attention comes from, demonstrates the problem with a code example, and introduces some common improvement methods.
1. An Overview of Self-Attention
Principle and Formulation
In the self-attention mechanism, the input sequence $X \in \mathbb{R}^{n \times d}$ is projected into three matrices, the queries $Q$, keys $K$, and values $V$, via the weight matrices $W_Q$, $W_K$, and $W_V$:
$$Q = X W_Q, \quad K = X W_K, \quad V = X W_V$$
The output of self-attention is computed as:
$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{Q K^T}{\sqrt{d_k}}\right) V$$
- $n$ is the input sequence length (the number of tokens).
- $d$ is the input feature dimension.
- $d_k$ is the dimension of the key vectors (typically $d_k = d / h$, where $h$ is the number of attention heads).
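To make the projection step concrete, here is a minimal sketch (the sequence length and dimensions are arbitrary example values, and the weights are random rather than learned):

import torch

n, d, d_k = 128, 512, 64            # example sequence length and dimensions
X = torch.randn(n, d)               # input sequence of n token embeddings

# Projection weights (random here; learned parameters in a real model)
W_Q = torch.randn(d, d_k)
W_K = torch.randn(d, d_k)
W_V = torch.randn(d, d_k)

Q, K, V = X @ W_Q, X @ W_K, X @ W_V
print(Q.shape, K.shape, V.shape)    # each is torch.Size([128, 64])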
Time Complexity Analysis
From the formula, the key operations in self-attention are:
- Computing $Q K^T$: the query matrix $Q \in \mathbb{R}^{n \times d_k}$ is multiplied by the transposed key matrix $K \in \mathbb{R}^{n \times d_k}$, producing an $n \times n$ attention score matrix. Complexity: $O(n^2 d_k)$.
- Softmax: normalization is performed over the $n \times n$ attention matrix. Complexity: $O(n^2)$.
- Multiplying the attention scores by $V$: the $n \times n$ attention weight matrix is multiplied by $V \in \mathbb{R}^{n \times d_v}$. Complexity: $O(n^2 d_v)$.
Altogether, the time complexity of self-attention is:
$$O(n^2 d_k + n^2 + n^2 d_v) \approx O(n^2 d)$$
- When $d$ is a constant, the complexity is dominated by the input sequence length $n$, i.e., it grows quadratically.
Space Complexity Analysis
The attention score matrix $Q K^T$ has size $n \times n$ and therefore requires $O(n^2)$ memory to store.
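To get a feel for what $O(n^2)$ memory means in practice, the following back-of-the-envelope sketch simply counts the bytes of a single float32 $n \times n$ score matrix (ignoring batch size, multiple heads, and activations kept for backpropagation):

# Memory needed just to store one float32 n x n attention matrix
for n in [1024, 4096, 16384, 65536]:
    bytes_needed = n * n * 4        # 4 bytes per float32 element
    print(f"n = {n:6d}: {bytes_needed / 1024**3:.4f} GiB")

# n =   1024: 0.0039 GiB
# n =   4096: 0.0625 GiB
# n =  16384: 1.0000 GiB
# n =  65536: 16.0000 GiB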
2. Code Example: Time and Memory Cost
The following code shows how the runtime of self-attention grows as the input sequence length increases:
import torch
import time

# Scaled dot-product self-attention
def self_attention(Q, K, V):
    attention_scores = torch.matmul(Q, K.transpose(-1, -2)) / torch.sqrt(
        torch.tensor(Q.shape[-1], dtype=torch.float32))
    attention_weights = torch.softmax(attention_scores, dim=-1)
    output = torch.matmul(attention_weights, V)
    return output

# Measure runtime for increasing sequence lengths
def test_attention_complexity():
    d_k = 64  # feature dimension
    for n in [128, 256, 512, 1024, 2048]:  # input sequence lengths
        Q = torch.randn((1, n, d_k))  # Query
        K = torch.randn((1, n, d_k))  # Key
        V = torch.randn((1, n, d_k))  # Value

        start_time = time.time()
        output = self_attention(Q, K, V)
        end_time = time.time()

        print(f"Sequence Length: {n}, Time Taken: {end_time - start_time:.6f} seconds, "
              f"Output Shape: {output.shape}")

if __name__ == "__main__":
    test_attention_complexity()
Example Output
Sequence Length: 128, Time Taken: 0.001200 seconds, Output Shape: torch.Size([1, 128, 64])
Sequence Length: 256, Time Taken: 0.004500 seconds, Output Shape: torch.Size([1, 256, 64])
Sequence Length: 512, Time Taken: 0.015800 seconds, Output Shape: torch.Size([1, 512, 64])
Sequence Length: 1024, Time Taken: 0.065200 seconds, Output Shape: torch.Size([1, 1024, 64])
Sequence Length: 2048, Time Taken: 0.260000 seconds, Output Shape: torch.Size([1, 2048, 64])
From the results, the computation time grows roughly fourfold each time the sequence length doubles, a clear sign of quadratic scaling.
3. Methods for Reducing the Quadratic Complexity
To reduce the computational cost of self-attention, researchers have proposed a range of optimizations, the main ones being:
1. Low-Rank Approximation
Use low-rank matrix decomposition to reduce the cost of computing $Q K^T$, for example:
- Linformer: approximates the $n \times n$ attention matrix with a low-rank factorization of shape $n \times k$ (where $k \ll n$), reducing the complexity to $O(nk)$.
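The following minimal sketch illustrates the Linformer idea (not the official implementation): keys and values are projected from length $n$ down to a fixed length $k$ along the sequence dimension, so the score matrix is $n \times k$ instead of $n \times n$. The projection matrices E and F are random here but learned in the real model.

import torch

def low_rank_attention(Q, K, V, E, F):
    # Q, K, V: (n, d_k); E, F: (k, n) sequence-length projections
    K_proj = E @ K                                  # (k, d_k)
    V_proj = F @ V                                  # (k, d_k)
    scores = Q @ K_proj.T / (Q.shape[-1] ** 0.5)    # (n, k) instead of (n, n)
    weights = torch.softmax(scores, dim=-1)
    return weights @ V_proj                         # (n, d_k)

n, d_k, k = 2048, 64, 256
Q, K, V = (torch.randn(n, d_k) for _ in range(3))
E, F = torch.randn(k, n), torch.randn(k, n)
print(low_rank_attention(Q, K, V, E, F).shape)      # torch.Size([2048, 64])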
2. Sparse Attention
- Longformer and BigBird: introduce local windows plus a small number of global attention positions so that only part of the attention scores is computed, avoiding the full $Q K^T$ and lowering the complexity to $O(n \log n)$ or $O(n)$.
3. Linear Attention
- Performer: uses a kernel trick to turn the attention computation into linear operations, reducing the complexity to $O(n d)$.
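The sketch below shows the kernel-trick idea behind linear attention; for brevity it uses the simple $\mathrm{elu}(x) + 1$ feature map (an assumption borrowed from kernelized-attention work, not Performer's random features). Applying a feature map $\phi$ to $Q$ and $K$ lets the product be reassociated as $\phi(Q)\,(\phi(K)^T V)$, so the $n \times n$ matrix is never formed.

import torch

def linear_attention(Q, K, V, eps=1e-6):
    # Kernelized attention sketch: no n x n matrix is materialized
    phi = lambda x: torch.nn.functional.elu(x) + 1  # positive feature map
    phi_Q, phi_K = phi(Q), phi(K)                   # (n, d_k) each
    KV = phi_K.T @ V                                # (d_k, d_v), computed once
    Z = phi_Q @ phi_K.sum(dim=0, keepdim=True).T    # (n, 1) normalizer
    return (phi_Q @ KV) / (Z + eps)                 # (n, d_v)

n, d_k = 2048, 64
Q, K, V = (torch.randn(n, d_k) for _ in range(3))
print(linear_attention(Q, K, V).shape)              # torch.Size([2048, 64])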
4. Blockwise Attention
Split the input sequence into multiple blocks and compute attention only within blocks (or between selected blocks); this is well suited to long-sequence tasks.
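As a simplified illustration (a sketch, not any particular published model): splitting the sequence into non-overlapping blocks of size $b$ and attending only within each block costs roughly $O(n b d)$ instead of $O(n^2 d)$.

import torch

def blockwise_attention(Q, K, V, block_size=128):
    # Attend only within non-overlapping blocks; assumes n % block_size == 0
    n, d_k = Q.shape
    outputs = []
    for start in range(0, n, block_size):
        q = Q[start:start + block_size]             # (b, d_k)
        k = K[start:start + block_size]
        v = V[start:start + block_size]
        scores = q @ k.T / (d_k ** 0.5)             # (b, b), never (n, n)
        outputs.append(torch.softmax(scores, dim=-1) @ v)
    return torch.cat(outputs, dim=0)                # (n, d_k)

n, d_k = 2048, 64
Q, K, V = (torch.randn(n, d_k) for _ in range(3))
print(blockwise_attention(Q, K, V).shape)           # torch.Size([2048, 64])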
4. Summary
In the Transformer's self-attention mechanism, computing $Q K^T$ and storing the $n \times n$ attention matrix make both the time and space cost scale as $O(n^2)$ in the sequence length. This is a significant challenge for long-sequence tasks such as long documents or DNA sequence analysis.
To address this problem, a variety of optimizations have been proposed in recent years, including low-rank approximation, sparse attention, and linear attention, which reduce the complexity from $O(n^2)$ to $O(n)$ or $O(n \log n)$ and allow Transformers to handle long sequences far more efficiently.
The code example and its results clearly show the practical impact of the quadratic complexity, and they also underline why these optimization methods matter.
English Version
The Quadratic Complexity of Self-Attention in Transformers and Possible Improvements
The core of the Transformer architecture in large language models (LLMs) is the self-attention mechanism. While it has proven revolutionary, its computational complexity and memory requirements grow quadratically as the input sequence length increases. This blog will explain the source of this quadratic complexity, demonstrate it with code, and discuss possible optimization methods.
1. Understanding Self-Attention
Mathematical Formulation
Given an input sequence $X \in \mathbb{R}^{n \times d}$ with $n$ tokens and $d$ features, the self-attention mechanism computes the query (Q), key (K), and value (V) matrices as follows:
$$Q = X W_Q, \quad K = X W_K, \quad V = X W_V$$
The output of the self-attention mechanism is calculated as:
$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{Q K^T}{\sqrt{d_k}}\right) V$$
Where:
- $n$: Sequence length
- $d$: Feature dimension
- $d_k$: Dimension of queries/keys (typically $d_k = d/h$ for multi-head attention with $h$ heads)
Time Complexity Analysis
The computational bottlenecks of self-attention are:
- Computing $Q K^T$: The query matrix $Q \in \mathbb{R}^{n \times d_k}$ is multiplied with the transposed key matrix $K^T \in \mathbb{R}^{d_k \times n}$, producing an $n \times n$ attention score matrix. Complexity: $O(n^2 d_k)$.
- Softmax operation: Softmax normalization is applied along each row of the $n \times n$ attention matrix. Complexity: $O(n^2)$.
- Computing weighted values: The $n \times n$ attention scores are multiplied by the value matrix $V \in \mathbb{R}^{n \times d_v}$. Complexity: $O(n^2 d_v)$.
Combining all these steps, the overall time complexity of self-attention is:
$$O(n^2 d)$$
When $d$ is fixed (a constant), the complexity primarily depends on $n$, making it quadratic.
Space Complexity
The attention score matrix $Q K^T$ has a size of $n \times n$, requiring $O(n^2)$ memory to store. This quadratic memory cost limits the model's ability to handle long sequences.
2. Code Demonstration: Quadratic Complexity in Practice
The following code measures the computation time of self-attention as the input sequence length increases:
import torch
import time

# Self-attention function
def self_attention(Q, K, V):
    attention_scores = torch.matmul(Q, K.transpose(-1, -2)) / torch.sqrt(
        torch.tensor(Q.shape[-1], dtype=torch.float32))
    attention_weights = torch.softmax(attention_scores, dim=-1)
    output = torch.matmul(attention_weights, V)
    return output

# Test different sequence lengths
def test_attention_complexity():
    d_k = 64  # Feature dimension
    for n in [128, 256, 512, 1024, 2048]:  # Sequence lengths
        Q = torch.randn((1, n, d_k))  # Query
        K = torch.randn((1, n, d_k))  # Key
        V = torch.randn((1, n, d_k))  # Value

        start_time = time.time()
        output = self_attention(Q, K, V)
        end_time = time.time()

        print(f"Sequence Length: {n}, Time Taken: {end_time - start_time:.6f} seconds, "
              f"Output Shape: {output.shape}")

if __name__ == "__main__":
    test_attention_complexity()
Example Output
Sequence Length: 128, Time Taken: 0.001200 seconds, Output Shape: torch.Size([1, 128, 64])
Sequence Length: 256, Time Taken: 0.004500 seconds, Output Shape: torch.Size([1, 256, 64])
Sequence Length: 512, Time Taken: 0.015800 seconds, Output Shape: torch.Size([1, 512, 64])
Sequence Length: 1024, Time Taken: 0.065200 seconds, Output Shape: torch.Size([1, 1024, 64])
Sequence Length: 2048, Time Taken: 0.260000 seconds, Output Shape: torch.Size([1, 2048, 64])
From the output, it is clear that the computation time increases quadratically with the sequence length $n$.
3. Solutions to Address the Quadratic Complexity
To address the inefficiency of quadratic complexity, several optimization methods have been proposed:
1. Low-Rank Approximation
Techniques like Linformer approximate the $n \times n$ attention matrix using low-rank decomposition:
- Complexity is reduced to $O(n k)$, where $k \ll n$.
2. Sparse Attention
Sparse attention mechanisms, such as Longformer and BigBird, compute attention only for selected tokens (e.g., local windows or global tokens):
- Complexity is reduced to $O(n \log n)$ or $O(n)$.
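A minimal way to see the idea (a sketch, not Longformer's or BigBird's actual implementation) is to mask the score matrix to a local band before the softmax. For clarity the sketch still builds the full matrix; an efficient sparse kernel would compute only the banded entries.

import torch

def local_window_attention(Q, K, V, window=64):
    # Each token attends only to tokens within +/- window positions
    n, d_k = Q.shape
    scores = Q @ K.T / (d_k ** 0.5)                      # (n, n) for clarity only
    idx = torch.arange(n)
    mask = (idx[None, :] - idx[:, None]).abs() > window  # True = outside the window
    scores = scores.masked_fill(mask, float("-inf"))
    return torch.softmax(scores, dim=-1) @ V

n, d_k = 1024, 64
Q, K, V = (torch.randn(n, d_k) for _ in range(3))
print(local_window_attention(Q, K, V).shape)             # torch.Size([1024, 64])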
3. Linear Attention
Linear attention, such as in Performer, uses kernel functions to approximate the attention mechanism, avoiding the explicit $Q K^T$ operation:
- Complexity becomes $O(n d)$.
4. Blockwise and Sliding-Window Attention
Divide the input sequence into smaller chunks or sliding windows and compute attention locally within each block:
- This approach significantly reduces the computational cost for long sequences.
4. Summary
The self-attention mechanism in Transformer models has a time complexity of $O(n^2 d)$ and a space complexity of $O(n^2)$, both of which grow quadratically with the sequence length. This becomes a bottleneck for long input sequences, such as lengthy documents or DNA sequences.
Through our code example, we demonstrated the quadratic increase in computational time as the sequence length grows. To address this limitation, several optimizations—such as low-rank approximations, sparse attention, and linear attention—have been introduced to scale Transformers to longer sequences efficiently.
By understanding and leveraging these methods, we can improve the efficiency of self-attention and unlock the potential of Transformers for applications involving extremely long sequences.
Postscript
Completed in Shanghai at 22:26 on December 17, 2024, with the assistance of the GPT-4o model.