【llm对话系统】大模型 Llama 源码分析之 Flash Attention

embedded/2025/2/2 15:27:08/

1. 写在前面

近年来,基于 Transformer 架构的大型语言模型 (LLM) 在自然语言处理 (NLP) 领域取得了巨大的成功。Transformer 的核心组件是自注意力 (Self-Attention) 机制,它允许模型捕捉输入序列中不同位置之间的关系。然而,标准的自注意力机制的计算复杂度与序列长度的平方成正比,这使得它在处理长序列时效率低下。

为了解决这个问题,Flash Attention 被提出,它是一种高效的注意力算法,通过利用现代 GPU 的特性,显著降低了计算复杂度和内存占用。本文将深入 Llama 源码,分析 Flash Attention 的实现逻辑,并与标准的自注意力机制进行比较。

2. Self-Attention 回顾

Self-Attention 的核心思想是:对于输入序列中的每个 token,都计算它与其他所有 token 之间的相关性,并根据这些相关性对所有 token 的表示进行加权求和,得到该 token 的新的表示。

标准的 Self-Attention 计算过程如下:

  1. 线性变换: 将输入序列的每个 token 的 embedding 向量 x_i 通过三个线性变换矩阵 W_q, W_k, W_v 映射成三个向量:q_i (query), k_i (key), v_i (value)。

    # 假设 embedding_dim = 512, seq_len = 1024
    import torch
    x = torch.randn(1, 1024, 512)  # batch_size=1, seq_len=1024, embedding_dim=512
    W_q = torch.randn(512, 512)
    W_k = torch.randn(512, 512)
    W_v = torch.randn(512, 512)q = x @ W_q  # (1, 1024, 512)
    k = x @ W_k  # (1, 1024, 512)
    v = x @ W_v  # (1, 1024, 512)
    
  2. 计算注意力分数: 计算每个 query q_i 与所有 key k_j 之间的点积,得到注意力分数 s_ij

    s = q @ k.transpose(-2, -1)  # (1, 1024, 1024)
    
  3. 缩放和掩码: 对注意力分数进行缩放 (除以 sqrt(d_k), d_k 是 key 向量的维度),并应用掩码 (mask) 操作 (例如,在解码器中屏蔽未来 token)。

    import math
    d_k = k.shape[-1]
    s = s / math.sqrt(d_k)
    # 假设我们不需要 mask
    
  4. Softmax: 对缩放后的注意力分数应用 softmax 函数,得到注意力权重 a_ij

    a = torch.softmax(s, dim=-1)  # (1, 1024, 1024)
    
  5. 加权求和: 使用注意力权重 a_ij 对所有 value 向量 v_j 进行加权求和,得到每个 token 的新的表示 y_i

    y = a @ v  # (1, 1024, 512)
    

问题: 上述计算过程中,sa 这两个矩阵的大小都是 (seq_len, seq_len),当 seq_len 很大时 (例如 4096),这两个矩阵会占用大量的显存,并且计算 softmax 和矩阵乘法也非常耗时。

3. Flash Attention 原理

Flash Attention 的核心思想是:避免将整个注意力矩阵 sa 存储在 GPU 的高速缓存 (HBM) 中,而是将输入数据分块 (tiling),每次只加载一小部分数据到 SRAM 中进行计算,并将结果写回 HBM。

Flash Attention 主要利用了以下两个技术:

3.1 Tiling (分块)

将 Q, K, V 矩阵分成多个 block,每次只计算一个 block 的注意力。例如,可以将一个 (1024, 512) 的矩阵分成 16 个 (256, 512) 的 block。

3.2 Recomputation (重计算)

在反向传播时,不存储中间的注意力权重 a,而是在需要的时候重新计算。由于计算 a 的开销相对较小,这种方法可以节省大量的显存。

4. Llama 中 Flash Attention 的实现

Llama 使用了 Flash Attention 的改进版本,即 Paged Attention。其核心思想与 Flash Attention 相同,但在处理长序列时更加高效。这里以llama2源码为例说明,其位于llama/model.py文件中,class Attention(nn.Module) 类下的forward函数中

以下是 Llama 源码中 Flash Attention 的简化版实现 (已去除部分细节):

def flash_attention(q, k, v, block_size):"""简化版的 Flash Attention 实现.Args:q: Query 矩阵 (B, H, N, D_head)k: Key 矩阵 (B, H, N, D_head)v: Value 矩阵 (B, H, N, D_head)block_size: 分块大小Returns:输出矩阵 (B, H, N, D_head)"""B, H, N, D_head = q.shapeO = torch.zeros_like(q)for i in range(0, N, block_size):# 加载当前 block 的数据到 SRAMqi = q[:, :, i:i + block_size, :]mi = -float('inf')  # 用于记录当前 block 的最大值li = 0.0  # 用于记录当前 block 的 softmax 的分母for j in range(0, N, block_size):# 加载当前 block 的数据到 SRAMkj = k[:, :, j:j + block_size, :]vj = v[:, :, j:j + block_size, :]# 计算注意力分数sij = torch.einsum('bhnd,bhmd->bhnm', qi, kj) / math.sqrt(D_head)# 更新最大值和 softmax 的分母mij = torch.max(sij, dim=-1).valuesli_new = torch.exp(mi - mij).unsqueeze(-1) * li + torch.sum(torch.exp(sij - mij.unsqueeze(-1)), dim=-1)# 更新输出O[:, :, i:i + block_size, :] = (li / li_new).unsqueeze(-1) * O[:, :, i:i + block_size, :] + \torch.einsum('bhnm,bhmd->bhnd', torch.exp(sij - mij.unsqueeze(-1)), vj) / li_new.unsqueeze(-1)mi = torch.max(mi, mij)li = li_newreturn O# 示例:假设 block_size = 256
q = torch.randn(1, 8, 1024, 64)  # batch_size=1, heads=8, seq_len=1024, d_head=64
k = torch.randn(1, 8, 1024, 64)
v = torch.randn(1, 8, 1024, 64)
o = flash_attention(q, k, v, block_size=256)
print(o.shape) # torch.Size([1, 8, 1024, 64])

代码解释:

  1. q, k, v 分别表示 query, key, value 矩阵, block_size 表示分块大小。
  2. O 是输出矩阵,初始化为全零。
  3. 外层循环遍历 Q 矩阵的 block。
  4. 内层循环遍历 K, V 矩阵的 block。
  5. sij 计算当前 block 的注意力分数。
  6. mili 分别用于记录当前 block 的最大值和 softmax 的分母,以保证数值稳定性。
  7. O 使用增量更新的方式计算最终的输出结果。

注意: 上述代码只是 Flash Attention 的简化版实现,实际的 Llama 源码中还包括了 mask, dropout, causal mask 等操作,并且使用了更高效的 CUDA kernel 来加速计算。

简化版实现说明

上面的代码实现了一个简化版本的Flash Attention算法。它通过两个嵌套的循环来处理查询(Q)、键(K)和值(V)矩阵,这些矩阵被分成了多个块(block)。这种分块处理的方式旨在减少计算过程中的内存占用,特别是对于那些拥有大量头的注意力机制(如多头注意力机制)来说,可以显著提高计算效率。下面我们来逐步解释这段代码的核心逻辑:

初始化输出张量 O

  • O 被初始化为与查询张量 q 相同形状的全零张量。这个张量将累积每个块计算的结果,最终形成完整的输出。

外循环:遍历Q的块

  • 代码通过 for i in range(0, N, block_size): 循环遍历 Q 矩阵的块。变量 i 表示当前处理的块在序列维度上的起始位置。

内循环:遍历K和V的块

  • 对于Q中的每个块,代码通过 for j in range(0, N, block_size): 循环遍历 K 和 V 矩阵的块。变量 j 表示 K 和 V 当前处理的块在序列维度上的起始位置。

注意力分数的计算

  • sij = torch.einsum('bhnd,bhmd->bhnm', qi, kj) / math.sqrt(D_head) 计算当前 Q 块和 K 块之间的注意力分数。这里使用了爱因斯坦求和标记法(einsum),这是一种简洁表示张量操作的方式。

更新最大值和softmax的分母

  • mij = torch.max(sij, dim=-1).values 计算 sij 在最后一个维度上的最大值。
  • li_new = torch.exp(mi - mij).unsqueeze(-1) * li + torch.sum(torch.exp(sij - mij.unsqueeze(-1)), dim=-1) 更新softmax的分母。这里使用了数值稳定的技巧,通过减去最大值来避免指数运算产生过大的数值。

更新输出

  • O[:, :, i:i + block_size, :] = ... 这行代码是整个算法中最关键的部分。它根据当前计算的注意力分数和值(V)来更新输出张量 O 的相应块。这里通过加权求和的方式,将之前步骤的结果累加到 O 上。

更新 mili

  • mi = torch.max(mi, mij) 更新到目前为止遇到的最大值。
  • li = li_new 更新softmax的分母。

总结

这段代码实现了一种高效的注意力机制,通过分块处理和数值稳定的softmax计算,减少了内存占用并提高了计算效率。尽管代码进行了一定的简化以突出核心逻辑,但它捕捉了Flash Attention算法的关键思想。在实际应用中,还需要考虑如何高效地在硬件上实现这些操作,以及如何处理边界情况和性能优化。

5. Flash Attention 与标准 Self-Attention 的比较

特性标准 Self-AttentionFlash Attention
计算复杂度O(N^2)O(N) (理论上, 实际取决于分块大小)
内存占用O(N^2)O(N) (理论上, 实际取决于分块大小)
速度
适用场景短序列长序列
实现复杂性简单复杂

http://www.ppmy.cn/embedded/158932.html

相关文章

DIY QMK量子键盘

最近放假了,趁这个空余在做一个分支项目,一款机械键盘,量子键盘取自固件名称QMK(Quantum Mechanical Keyboard)。 键盘作为计算机或其他电子设备的重要输入设备之一,通过将按键的物理动作转换为数字信号&am…

2024-2025自动驾驶技术演进与产业破局的深度实践——一名自动驾驶算法工程师的年度技术总结与行业洞察

一、引言:站在自动驾驶的"技术奇点" 2024年是自动驾驶行业从"技术验证"迈向"商业化落地"的关键转折点。从特斯拉FSD V12的端到端技术突破,到中国L3法规的破冰,从大模型重构感知架构,到城市NOA的&qu…

Java LongAdder 分段锁思想

专栏系列文章地址:https://blog.csdn.net/qq_26437925/article/details/145290162 本文目标: 理解分段锁思想,了解LongAdder的原理 目录 LongAdder基本原理源码分析unsafe int 操作的一些方法Cell 对象(Striped64类下的静态内部类)add(long…

shiro学习五:使用springboot整合shiro。在前面学习四的基础上,增加shiro的缓存机制,源码讲解:认证缓存、授权缓存。

文章目录 前言1. 直接上代码最后在讲解1.1 新增的pom依赖1.2 RedisCache.java1.3 RedisCacheManager.java1.4 jwt的三个类1.5 ShiroConfig.java新增Bean 2. 源码讲解。2.1 shiro 缓存的代码流程。2.2 缓存流程2.2.1 认证和授权简述2.2.2 AuthenticatingRealm.getAuthentication…

大模型应用的10个架构挑战

[引] 在英国,时差有点乱。拾起年初的文字,迎接新春大吉! ChatGPT从正式发布到拥有1亿用户仅仅用了5天的时间,基于大型语言模型(简称大模型,或基础模型)的应用给软件行业乃至整个社会带来巨大的影…

Spring JDBC:简化数据库操作的利器

前言 Spring框架为Java开发者提供了多种技术解决方案,Spring JDBC作为其中的核心模块之一,帮助开发者更加轻松、简洁地进行数据库操作。本文将介绍Spring JDBC的概念、优势、如何使用以及常见的应用场景。 什么是Spring JDBC? Spring JDBC是…

数据库、数据仓库、数据湖有什么不同

数据库、数据仓库和数据湖是三种不同的数据存储和管理技术,它们在用途、设计目标、数据处理方式以及适用场景上存在显著差异。以下将从多个角度详细说明它们之间的区别: 1. 数据结构与存储方式 数据库: 数据库主要用于存储结构化的数据&…

1.Template Method 模式

模式定义 定义一个操作中的算法的骨架(稳定),而将一些步骤延迟(变化)到子类中。Template Method 使得子类可以不改变(复用)一个算法的结构即可重定义(override 重写)该算法的某些特…