The Quadratic Complexity of Self-Attention in Transformers and Methods to Improve It


As the input sequence length of large language models (LLMs) grows, the computation and memory cost of the Transformer's core module, the self-attention mechanism, both grow quadratically. This not only limits the model's ability to handle long sequences, but also becomes a major bottleneck during training and inference.

This post explains where the quadratic complexity of self-attention comes from, demonstrates the problem with a code example, and introduces some common ways to mitigate it.


1. Introduction to the Self-Attention Mechanism

Principle and Formulation

In the self-attention mechanism, the input sequence $X \in \mathbb{R}^{n \times d}$ is mapped to three matrices: the query $Q$, the key $K$, and the value $V$, obtained through the weight matrices $W_Q$, $W_K$, and $W_V$:

$$Q = X W_Q, \quad K = X W_K, \quad V = X W_V$$

The output of self-attention is computed as:

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{Q K^T}{\sqrt{d_k}}\right) V$$

  • $n$ is the length of the input sequence (number of tokens).
  • $d$ is the input feature dimension.
  • $d_k$ is the key dimension (typically $d_k = d / h$, where $h$ is the number of attention heads).

Time Complexity Analysis

From the formula, the key operations in self-attention are:

  1. $Q K^T$: multiplying the query matrix $Q \in \mathbb{R}^{n \times d_k}$ by the transposed key matrix $K \in \mathbb{R}^{n \times d_k}$ produces an $n \times n$ attention score matrix.

    • Complexity: $O(n^2 d_k)$.
  2. Softmax: normalizing the $n \times n$ attention matrix costs $O(n^2)$.

  3. Multiplying the attention scores by $V$: multiplying the $n \times n$ attention matrix by the value matrix $V \in \mathbb{R}^{n \times d_v}$ costs $O(n^2 d_v)$.

Putting these together, the time complexity of self-attention is:

$$O(n^2 d_k + n^2 + n^2 d_v) \approx O(n^2 d)$$

  • When $d$ is a constant, the complexity is dominated by the input sequence length $n$, i.e., it grows quadratically.

Space Complexity Analysis

The attention score matrix $Q K^T$ has size $n \times n$ and therefore requires $O(n^2)$ memory to store.
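To make this concrete, the short sketch below estimates the storage needed for a single $n \times n$ score matrix in float32 (4 bytes per element); the sequence lengths are illustrative assumptions, not figures from the measurements in this post:

```python
def attention_matrix_bytes(n: int, bytes_per_element: int = 4) -> int:
    """Memory needed to store one n x n attention score matrix (per head)."""
    return n * n * bytes_per_element

# Illustrative sequence lengths (assumed for this example).
for n in [1024, 4096, 32768]:
    mib = attention_matrix_bytes(n) / (1024 ** 2)
    print(f"n = {n:6d}: about {mib:,.1f} MiB per head for the score matrix")
```

At $n = 32{,}768$ a single head already needs about 4 GiB just for the scores, before gradients or multiple heads are taken into account.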


2. Code Example: Computation Time and Memory Cost

The following code shows how the time cost of self-attention grows as the input sequence length increases:

```python
import torch
import time

# Self-attention mechanism
def self_attention(Q, K, V):
    attention_scores = torch.matmul(Q, K.transpose(-1, -2)) / torch.sqrt(torch.tensor(Q.shape[-1], dtype=torch.float32))
    attention_weights = torch.softmax(attention_scores, dim=-1)
    output = torch.matmul(attention_weights, V)
    return output

# Measure run time for different sequence lengths
def test_attention_complexity():
    d_k = 64  # Feature dimension
    for n in [128, 256, 512, 1024, 2048]:  # Input sequence lengths
        Q = torch.randn((1, n, d_k))  # Query
        K = torch.randn((1, n, d_k))  # Key
        V = torch.randn((1, n, d_k))  # Value
        start_time = time.time()
        output = self_attention(Q, K, V)
        end_time = time.time()
        print(f"Sequence Length: {n}, Time Taken: {end_time - start_time:.6f} seconds, Output Shape: {output.shape}")

if __name__ == "__main__":
    test_attention_complexity()
```

Example Output

Sequence Length: 128, Time Taken: 0.001200 seconds, Output Shape: torch.Size([1, 128, 64])
Sequence Length: 256, Time Taken: 0.004500 seconds, Output Shape: torch.Size([1, 256, 64])
Sequence Length: 512, Time Taken: 0.015800 seconds, Output Shape: torch.Size([1, 512, 64])
Sequence Length: 1024, Time Taken: 0.065200 seconds, Output Shape: torch.Size([1, 1024, 64])
Sequence Length: 2048, Time Taken: 0.260000 seconds, Output Shape: torch.Size([1, 2048, 64])

The results show that the computation time grows roughly quadratically as the sequence length increases.


3. Methods to Reduce the Quadratic Complexity

To reduce the computational complexity of self-attention, researchers have proposed a number of optimizations, mainly including:

1. Low-Rank Approximation

Use low-rank matrix factorization to reduce the cost of computing $Q K^T$, for example:

  • Linformer: approximates the $n \times n$ attention matrix with a low-rank projection of size $n \times k$ (where $k \ll n$), reducing the complexity to $O(nk)$; a minimal sketch follows below.
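The sketch below illustrates the low-rank idea in the Linformer style: the keys and values are projected from length $n$ down to a fixed length $k$ before the scores are computed, so the score matrix is $n \times k$ rather than $n \times n$. The random projection matrices and the dimensions used here are assumptions for illustration, not the learned parameters of the actual model.

```python
import torch

def low_rank_attention(Q, K, V, E, F):
    """Linformer-style sketch: project K and V along the sequence axis
    with E, F of shape (k, n) before attention, so scores are (n, k)."""
    d_k = Q.shape[-1]
    K_proj = torch.matmul(E, K)   # (batch, k, d_k)
    V_proj = torch.matmul(F, V)   # (batch, k, d_k)
    scores = torch.matmul(Q, K_proj.transpose(-1, -2)) / d_k ** 0.5  # (batch, n, k)
    weights = torch.softmax(scores, dim=-1)
    return torch.matmul(weights, V_proj)  # (batch, n, d_k)

n, k, d_k = 2048, 256, 64          # illustrative sizes
Q = K = V = torch.randn(1, n, d_k)
E = F = torch.randn(k, n)          # stand-ins for learned projections
print(low_rank_attention(Q, K, V, E, F).shape)  # torch.Size([1, 2048, 64])
```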

2. Sparse Attention

  • Longformer and BigBird: introduce local attention windows together with a few global attention positions, so only part of the attention scores are computed instead of the full $Q K^T$, reducing the complexity to $O(n \log n)$ or $O(n)$.

3. Linear Attention

  • Performer: uses a kernel trick to turn the attention computation into a linear operation, reducing the complexity to $O(n d)$; a minimal sketch follows below.
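As a rough illustration of the kernel idea (a simplified linear-attention sketch, not Performer's actual randomized FAVOR+ features), the code below replaces the softmax with a positive feature map $\phi(x) = \mathrm{elu}(x) + 1$ and reorders the multiplications so the $n \times n$ score matrix is never materialized:

```python
import torch
import torch.nn.functional as F

def linear_attention(Q, K, V, eps=1e-6):
    """Kernelized attention sketch: cost is linear in n,
    because phi(K)^T V is computed first instead of Q K^T."""
    phi_q = F.elu(Q) + 1  # (batch, n, d_k), positive features
    phi_k = F.elu(K) + 1  # (batch, n, d_k)
    kv = torch.matmul(phi_k.transpose(-1, -2), V)             # (batch, d_k, d_v)
    k_sum = phi_k.sum(dim=1, keepdim=True).transpose(-1, -2)  # (batch, d_k, 1)
    normalizer = torch.matmul(phi_q, k_sum) + eps             # (batch, n, 1)
    return torch.matmul(phi_q, kv) / normalizer               # (batch, n, d_v)

Q = K = V = torch.randn(1, 2048, 64)
print(linear_attention(Q, K, V).shape)  # torch.Size([1, 2048, 64])
```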

4. Blockwise Attention

Split the input sequence into blocks and compute attention only within blocks (or between selected blocks); this is well suited to long-sequence tasks.


4. Summary

In the Transformer's self-attention mechanism, computing $Q K^T$ and storing the $n \times n$ attention matrix make both the time and space complexity grow as $O(n^2)$ in the sequence length. This is a significant challenge for long-sequence tasks such as long documents or DNA sequence analysis.

To address this, a variety of optimizations have been proposed in recent years, including low-rank approximation, sparse attention, and linear attention, which reduce the complexity from $O(n^2)$ to $O(n)$ or $O(n \log n)$ and allow Transformers to handle long sequences much more efficiently.

The code example and measurements above clearly show the practical impact of the quadratic complexity, and also underline the importance of these optimizations.

English Version

The Quadratic Complexity of Self-Attention in Transformers and Possible Improvements

The core of the Transformer architecture in large language models (LLMs) is the self-attention mechanism. While it has proven revolutionary, its computational complexity and memory requirements grow quadratically as the input sequence length increases. This blog will explain the source of this quadratic complexity, demonstrate it with code, and discuss possible optimization methods.


1. Understanding Self-Attention

Mathematical Formulation

Given an input sequence $X \in \mathbb{R}^{n \times d}$ with $n$ tokens and $d$ features, the self-attention mechanism computes the query ($Q$), key ($K$), and value ($V$) matrices as follows:

$$Q = X W_Q, \quad K = X W_K, \quad V = X W_V$$

The output of the self-attention mechanism is calculated as:

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{Q K^T}{\sqrt{d_k}}\right) V$$

Where:

  • $n$: Sequence length
  • $d$: Feature dimension
  • $d_k$: Dimension of queries/keys (typically $d_k = d/h$ for multi-head attention with $h$ heads)

Time Complexity Analysis

The computational bottlenecks of self-attention are:

  1. Computing $Q K^T$:
    The query matrix $Q \in \mathbb{R}^{n \times d_k}$ is multiplied with the transposed key matrix $K^T \in \mathbb{R}^{d_k \times n}$, producing an $n \times n$ attention score matrix.
    Complexity: $O(n^2 d_k)$.

  2. Softmax Operation:
    Softmax normalization is applied along each row of the $n \times n$ attention matrix.
    Complexity: $O(n^2)$.

  3. Computing Weighted Values:
    The $n \times n$ attention scores are multiplied by the value matrix $V \in \mathbb{R}^{n \times d_v}$.
    Complexity: $O(n^2 d_v)$.

Combining all these steps, the overall time complexity of self-attention is:

$$O(n^2 d)$$

When ( d d d ) is fixed (a constant), the complexity primarily depends on ( n n n ), making it quadratic.


Space Complexity

The attention score matrix $Q K^T$ has a size of $n \times n$, requiring $O(n^2)$ memory to store. This quadratic memory cost limits the model’s ability to handle long sequences.


2. Code Demonstration: Quadratic Complexity in Practice

The following code measures the computation time of self-attention as the input sequence length increases:

```python
import torch
import time

# Self-attention function
def self_attention(Q, K, V):
    attention_scores = torch.matmul(Q, K.transpose(-1, -2)) / torch.sqrt(torch.tensor(Q.shape[-1], dtype=torch.float32))
    attention_weights = torch.softmax(attention_scores, dim=-1)
    output = torch.matmul(attention_weights, V)
    return output

# Test different sequence lengths
def test_attention_complexity():
    d_k = 64  # Feature dimension
    for n in [128, 256, 512, 1024, 2048]:  # Sequence lengths
        Q = torch.randn((1, n, d_k))  # Query
        K = torch.randn((1, n, d_k))  # Key
        V = torch.randn((1, n, d_k))  # Value
        start_time = time.time()
        output = self_attention(Q, K, V)
        end_time = time.time()
        print(f"Sequence Length: {n}, Time Taken: {end_time - start_time:.6f} seconds, Output Shape: {output.shape}")

if __name__ == "__main__":
    test_attention_complexity()
```

Example Output

Sequence Length: 128, Time Taken: 0.001200 seconds, Output Shape: torch.Size([1, 128, 64])
Sequence Length: 256, Time Taken: 0.004500 seconds, Output Shape: torch.Size([1, 256, 64])
Sequence Length: 512, Time Taken: 0.015800 seconds, Output Shape: torch.Size([1, 512, 64])
Sequence Length: 1024, Time Taken: 0.065200 seconds, Output Shape: torch.Size([1, 1024, 64])
Sequence Length: 2048, Time Taken: 0.260000 seconds, Output Shape: torch.Size([1, 2048, 64])

From the output, it is clear that the computation time increases quadratically with the sequence length $n$.


3. Solutions to Address the Quadratic Complexity

To address the inefficiency of quadratic complexity, several optimization methods have been proposed:

1. Low-Rank Approximation

Techniques like Linformer approximate the $n \times n$ attention matrix using low-rank decomposition:

  • Complexity is reduced to $O(nk)$, where $k \ll n$.

2. Sparse Attention

Sparse attention mechanisms, such as Longformer and BigBird, compute attention only for selected tokens (e.g., local windows or global tokens):

  • Complexity is reduced to $O(n \log n)$ or $O(n)$. A sliding-window sketch follows below.
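As a rough illustration of the local-window pattern (a sketch only; efficient Longformer/BigBird implementations use banded kernels and add global tokens), the code below masks every score outside a window of width $w$ before the softmax. Note that this naive version still materializes the full $n \times n$ matrix, so it shows the attention pattern rather than the memory savings:

```python
import torch

def sliding_window_attention(Q, K, V, window: int):
    """Each position attends only to neighbors within `window` steps."""
    n, d_k = Q.shape[-2], Q.shape[-1]
    scores = torch.matmul(Q, K.transpose(-1, -2)) / d_k ** 0.5
    idx = torch.arange(n)
    outside = (idx[None, :] - idx[:, None]).abs() > window  # True = masked out
    scores = scores.masked_fill(outside, float("-inf"))
    weights = torch.softmax(scores, dim=-1)
    return torch.matmul(weights, V)

Q = K = V = torch.randn(1, 512, 64)
print(sliding_window_attention(Q, K, V, window=16).shape)  # torch.Size([1, 512, 64])
```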

3. Linear Attention

Linear attention, such as in Performer, uses kernel functions to approximate the attention mechanism, avoiding the $Q K^T$ operation:

  • Complexity becomes $O(nd)$.

4. Blockwise and Sliding-Window Attention

Divide the input sequence into smaller chunks or sliding windows and compute attention locally within each block:

  • This approach significantly reduces the computational cost for long sequences; a chunked sketch follows below.
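The sketch below shows the simplest blockwise variant, under the assumption that attention is restricted to non-overlapping chunks of a fixed size (practical systems usually add cross-block or sliding connections on top of this): each chunk of length $b$ attends only to itself, so the cost drops from $O(n^2 d)$ to $O(n b d)$.

```python
import torch

def blockwise_attention(Q, K, V, block_size: int):
    """Chunked attention sketch: full attention is computed independently
    inside each non-overlapping block. Assumes n is divisible by block_size."""
    batch, n, d_k = Q.shape
    nb = n // block_size
    Qb = Q.reshape(batch, nb, block_size, d_k)
    Kb = K.reshape(batch, nb, block_size, d_k)
    Vb = V.reshape(batch, nb, block_size, d_k)
    scores = torch.matmul(Qb, Kb.transpose(-1, -2)) / d_k ** 0.5  # (batch, nb, b, b)
    weights = torch.softmax(scores, dim=-1)
    return torch.matmul(weights, Vb).reshape(batch, n, d_k)

Q = K = V = torch.randn(1, 2048, 64)
print(blockwise_attention(Q, K, V, block_size=128).shape)  # torch.Size([1, 2048, 64])
```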

4. Summary

The self-attention mechanism in Transformer models has a time complexity of $O(n^2 d)$ and a space complexity of $O(n^2)$, both of which grow quadratically with sequence length. This becomes a bottleneck for long input sequences, such as lengthy documents or DNA sequences.

Through our code example, we demonstrated the quadratic increase in computational time as the sequence length grows. To address this limitation, several optimizations—such as low-rank approximations, sparse attention, and linear attention—have been introduced to scale Transformers to longer sequences efficiently.

By understanding and leveraging these methods, we can improve the efficiency of self-attention and unlock the potential of Transformers for applications involving extremely long sequences.

Postscript

Written in Shanghai on December 17, 2024 at 22:26, with the assistance of the GPT-4o model.

