FlashMLA(DeepSeek开源周,第一个框架):含源码分析

server/2025/3/4 12:22:36/

1. 概述

FlashMLA 是由 DeepSeek 原创开发的一种深度学习框架,专门用于加速多头注意力机制(MLA)架构的推理过程。它通过优化内存管理和计算效率,显著提升了模型在高性能 GPU 上的推理速度。FlashMLA 主要适用于 DeepSeek 的架构模型(如 DeepSeek-R1 和 DeepSeek-V3),并专为 NVIDIA H 系列显卡(如 H800 SXM5)进行了深度优化。

2.开源与代码分析

目录如下:

flash_api.cpp
#include <torch/python.h>
#include <torch/nn/functional.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
​
#include <cutlass/fast_math.h>
​
#include "flash_mla.h"
#include "static_switch.h"
​
/*** 宏定义:CHECK_DEVICE、CHECK_SHAPE 和 CHECK_CONTIGUOUS 是用于验证输入张量的设备、形状和存储格式的便捷工具。
硬件检查:通过 at::cuda::getCurrentDeviceProperties 获取 GPU 属性,确保代码在兼容的硬件(如 SM90 架构)上运行。
张量操作:view 和 reshape 用于调整输入张量的维度,以适应不同的计算需求。
CUDA 内核调用:run_mha_fwd_splitkv_mla 是核心的 CUDA 内核函数,负责执行实际的注意力计算。
调度元数据:tile_scheduler_metadata 和 num_splits 用于控制 GPU 的计算调度,优化计算资源的分配。*/
/*** get_mla_metadata 函数
根据输入的序列长度和硬件配置(如 GPU 架构和线程块大小),计算并返回适用于多输入注意力(MLA)的调度元数据和分块信息。
主要是为了在大规模或分布式计算场景中优化 GPU 资源的分配和计算调度。
mha_fwd_kvcache_mla 函数
实现多头注意力(MHA)的前向计算,支持分块和缓存优化,特别适用于大规模模型训练中的多输入场景。
使用了各种优化手段(如峰值归一化、分组计算、调度机制等),以提高计算效率和性能。
Pybind11 绑定
将 C++ 函数绑定到 Python,使得 Python 用户能够通过 PyTorch 扩展调用高性能的 C++ 编写的注意力计算函数。*/
// 定义检查设备的宏
#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA")
// 定义检查张量形状的宏
#define CHECK_SHAPE(x, ...) \TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
// 定义检查张量是否连续的宏
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
​
// 获取多输入注意力(MLA)的元数据信息,用于调度和优化计算
std::vector<at::Tensor> get_mla_metadata(at::Tensor &seqlens_k,               // 输入序列长度张量(形状为 [batch_size])const int num_heads_per_head_k,      // 每个 Query 头对应的 Key 头的数量const int num_heads_k                // Key 头的总数量
) {// 这些常量对应 GPU 中线程块的大小(用于调度)static constexpr int block_size_m = 64;static constexpr int block_size_n = 64;static constexpr int fixed_overhead_num_blocks = 5;
​CHECK_DEVICE(seqlens_k);                     // 确保输入在 CUDA 设备上TORCH_CHECK(seqlens_k.is_contiguous());      // 确保输入是连续的TORCH_CHECK(seqlens_k.dtype() == torch::kInt32); // 确保数据类型是 int32
​int batch_size = seqlens_k.size(0);          // 获取批处理大小int *seqlens_k_ptr = seqlens_k.data_ptr<int>(); // 获取序列长度的指针
​auto options = seqlens_k.options();          // 获取张量的选项(如数据类型和设备)
​// 获取当前 GPU 的属性auto dprops = at::cuda::getCurrentDeviceProperties();int sm_count = dprops->multiProcessorCount;  // 当前 GPU 的多处理器数量(SM 数量)
​// 计算每个 SM 分配的线程块的数量int num_sm_parts = sm_count / num_heads_k / cutlass::ceil_div(num_heads_per_head_k, block_size_m);
​// 初始化存储调度元数据和分块数的张量auto tile_scheduler_metadata = torch::empty({num_sm_parts, TileSchedulerMetaDataSize}, options);auto num_splits = torch::empty({batch_size + 1}, options);
​int *tile_scheduler_metadata_ptr = tile_scheduler_metadata.data_ptr<int>();int *num_splits_ptr = num_splits.data_ptr<int>();
​// 确保设备和 CUDA 流一致at::cuda::CUDAGuard device_guard{(char)seqlens_k.get_device()};auto stream = at::cuda::getCurrentCUDAStream().stream();
​// 设置调度参数Mla_metadata_params params = {};params.seqlens_k_ptr = seqlens_k_ptr;                // 序列长度指针params.tile_scheduler_metadata_ptr = tile_scheduler_metadata_ptr; // 元数据指针params.num_splits_ptr = num_splits_ptr;              // 分块数指针params.batch_size = batch_size;                      // 批处理大小params.block_size_n = block_size_n;                  // 线程块大小(n 方向)params.fixed_overhead_num_blocks = fixed_overhead_num_blocks; // 固定开销的线程块数量params.num_sm_parts = num_sm_parts;                  // SM 的划分数量
​// 调用 CUDA 函数计算元数据get_mla_metadata_func(params, stream);
​return {tile_scheduler_metadata, num_splits}; // 返回调度元数据和分块数
}
​
// 多头注意力(MHA)前向计算函数,使用分块和缓存优化(适用大模型训练中的多输入场景)
std::vector<at::Tensor> mha_fwd_kvcache_mla(at::Tensor &q,                               // 查询张量(形状为 [batch_size, seqlen_q, num_heads, head_size])const at::Tensor &kcache,                    // Key 缓存(形状为 [num_blocks, page_block_size, num_heads_k, head_size])c10::optional<const at::Tensor> &vcache_,    // Value 缓存(形状为 [num_blocks, page_block_size, num_heads_k, head_size_v])const int head_size_v,                       // Value 的头大小const at::Tensor &seqlens_k,                 // Key 的序列长度(形状为 [batch_size])const at::Tensor &block_table,               // 分块表(形状为 [batch_size, max_num_blocks_per_seq])const float softmax_scale,                   // Softmax 缩放因子bool is_causal,                              // 是否启用因果机制(避免未来时间步的干扰)const at::Tensor &tile_scheduler_metadata,   // 调度元数据(形状为 [num_sm_parts, TileSchedulerMetaDataSize])const at::Tensor &num_splits                 // 分块数(形状为 [batch_size +1])
) {auto dprops = at::cuda::getCurrentDeviceProperties();bool is_sm90 = dprops->major == 9 && dprops->minor == 0; // 检查是否是 SM90 架构的 GPU
​TORCH_CHECK(is_sm90); // 必须是 SM90 架构的 GPU
​at::Tensor vcache = vcache_.has_value() ? vcache_.value() : kcache; // 如果 Value 缓存存在,使用它;否则使用 Key 缓存
​// 检查数据类型一致性auto q_dtype = q.dtype();TORCH_CHECK(q_dtype == torch::kBFloat16);TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype");
​CHECK_DEVICE(q); CHECK_DEVICE(kcache); CHECK_DEVICE(vcache); // 确保输入在 CUDA 设备上
​// 确保张量是行优先存储的(即最后一个维度是连续的)TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");TORCH_CHECK(kcache.stride(-1) == 1, "Input tensor must have contiguous last dimension");TORCH_CHECK(vcache.stride(-1) == 1, "Input tensor must have contiguous last dimension");
​CHECK_DEVICE(block_table); // 确保分块表在 CUDA 设备上TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension");
​// 获取输入张量的大小const auto sizes = q.sizes();const int batch_size = sizes[0];const int seqlen_q_ori = sizes[1];const int num_heads_ori = sizes[2];const int head_size = sizes[3];
​TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");TORCH_CHECK(head_size_v % 32 == 0, "head_size_v should be a multiple of 32");
​const int max_num_blocks_per_seq = block_table.size(1);const int num_blocks = kcache.size(0);const int page_block_size = kcache.size(1);const int num_heads_k = kcache.size(2);
​TORCH_CHECK(batch_size > 0, "batch size must be postive");TORCH_CHECK(num_heads_ori % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
​if (seqlen_q_ori == 1) { is_causal = false; } // 单个时间步时不启用因果机制
​const int ngroups = num_heads_ori / num_heads_k;  // 分组数量const int seqlen_q = seqlen_q_ori * ngroups;      // 调整后的查询序列长度const int num_heads = num_heads_k;                // 注意力头的数量
​// 调整输入张量的形状q = q.view({batch_size, seqlen_q_ori, num_heads_k, ngroups, head_size}).transpose(2, 3).reshape({batch_size, seqlen_q, num_heads, head_size});
​int head_size_k = head_size;
​// 检查张量形状CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, head_size_k);if (vcache_.has_value()) { CHECK_SHAPE(vcache, num_blocks, page_block_size, num_heads_k, head_size_v); }CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
​TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32");CHECK_DEVICE(seqlens_k);CHECK_CONTIGUOUS(seqlens_k);CHECK_SHAPE(seqlens_k, batch_size);
​at::cuda::CUDAGuard device_guard{(char)q.get_device()}; // 确保设备与输入一致
​// 初始化输出张量auto opts = q.options();at::Tensor out = torch::empty({batch_size, seqlen_q, num_heads, head_size_v}, opts);at::Tensor softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
​// 设置注意力参数Flash_fwd_mla_params params = {};params.b = batch_size;                          // 批处理大小params.seqlen_q = seqlen_q;                     // 查询序列长度params.cu_seqlens_k = seqlens_k.data_ptr<int>();// Key 的累积序列长度params.h = num_heads;                           // 注意力头数量params.h_h_k_ratio = num_heads / num_heads_k;   // 注意力头的比例params.ngroups = ngroups;                       // 分组数量params.is_causal = is_causal;                   // 是否是因果注意力params.d = head_size;                           // Query/Key 的头大小params.d_v = head_size_v;                       // Value 的头大小params.scale_softmax = softmax_scale;           // Softmax 的缩放因子params.scale_softmax_log2 = float(softmax_scale * M_LOG2E); // Softmax 的缩放因子(对数形式)
​// 设置输入和输出张量的指针和 Strideparams.q_ptr = q.data_ptr();                    // 查询指针params.k_ptr = kcache.data_ptr();               // Key 指针params.v_ptr = vcache.data_ptr();               // Value 指针params.o_ptr = out.data_ptr();                  // 输出指针params.softmax_lse_ptr = softmax_lse.data_ptr(); // Softmax 最终值的指针
​params.q_batch_stride = q.stride(0);            // 查询的批处理 Strideparams.k_batch_stride = kcache.stride(0);       // Key 的批处理 Strideparams.v_batch_stride = vcache.stride(0);       // Value 的批处理 Strideparams.o_batch_stride = out.stride(0);          // 输出的批处理 Stride
​params.q_row_stride = q.stride(-3);             // 查询的行 Strideparams.k_row_stride = kcache.stride(-3);        // Key 的行 Strideparams.v_row_stride = vcache.stride(-3);        // Value 的行 Strideparams.o_row_stride = out.stride(-3);           // 输出的行 Stride
​params.q_head_stride = q.stride(-2);            // 查询的头 Strideparams.k_head_stride = kcache.stride(-2);       // Key 的头 Strideparams.v_head_stride = vcache.stride(-2);       // Value 的头 Strideparams.o_head_stride = out.stride(-2);          // 输出的头 Stride
​params.block_table = block_table.data_ptr<int>(); // 分块表的指针params.block_table_batch_stride = block_table.stride(0); // 分块表的批处理 Strideparams.page_block_size = page_block_size;         // 分块的大小
​// 检查调度元数据和分块数的参数TORCH_CHECK(tile_scheduler_metadata.dtype() == torch::kInt32, "tile_scheduler_metadata must have dtype int32");TORCH_CHECK(tile_scheduler_metadata.size(1) == TileSchedulerMetaDataSize);CHECK_DEVICE(tile_scheduler_metadata);CHECK_CONTIGUOUS(tile_scheduler_metadata);params.tile_scheduler_metadata_ptr = tile_scheduler_metadata.data_ptr<int>();params.num_sm_parts = tile_scheduler_metadata.size(0);
​TORCH_CHECK(num_splits.dtype() == torch::kInt32, "num_splits must have dtype int32");CHECK_DEVICE(num_splits);CHECK_CONTIGUOUS(num_splits);params.num_splits_ptr = num_splits.data_ptr<int>();
​// 初始化累加变量at::Tensor softmax_lse_accum = torch::empty({batch_size + params.num_sm_parts, num_heads, seqlen_q}, opts.dtype(at::kFloat));at::Tensor out_accum = torch::empty({batch_size + params.num_sm_parts, num_heads, seqlen_q, head_size_v}, opts.dtype(at::kFloat));params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();params.oaccum_ptr = out_accum.data_ptr();
​// 获取当前 CUDA 流auto stream = at::cuda::getCurrentCUDAStream().stream();
​TORCH_CHECK(head_size == 576); // 确保头大小满足特定条件
​// 调用核心注意力计算函数(CUDA 内核)run_mha_fwd_splitkv_mla<cutlass::bfloat16_t, 576>(params, stream);
​// 重新调整输出张量的形状out = out.view({batch_size, seqlen_q_ori, ngroups, num_heads_k, head_size_v}).transpose(2, 3).reshape({batch_size, seqlen_q_ori, num_heads_ori, head_size_v});
​softmax_lse = softmax_lse.view({batch_size, num_heads_k, seqlen_q_ori, ngroups}).transpose(2, 3).reshape({batch_size, num_heads_ori, seqlen_q_ori});
​return {out, softmax_lse}; // 返回注意力输出和 Softmax 最终值
}
​
// 使用 Pybind11 将 C++ 函数绑定到 Python 扩展
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {m.doc() = "FlashMLA"; // 扩展的描述信息
​// 绑定函数到 Pythonm.def("get_mla_metadata", &get_mla_metadata);   // 绑定元数据函数m.def("fwd_kvcache_mla", &mha_fwd_kvcache_mla); // 绑定注意力前向计算函数
}
flash_fwd_mla_bf16_sm90.cu
#include "flash_fwd_mla_kernel.h"
​
/*** #include "flash_fwd_mla_kernel.h"
这是一个头文件包含指令,用于引入 run_mha_fwd_splitkv_mla 函数的定义或实现。
flash_fwd_mla_kernel.h 文件可能包含函数的声明、相关数据类型和常量的定义等。
template void run_mha_fwd_splitkv_mla<cutlass::bfloat16_t, 576>
这是显式实例化一个模板函数,run_mha_fwd_splitkv_mla 是一个模板函数。
cutlass::bfloat16_t 是模板的类型参数,表示使用的数据类型(bfloat16 浮点格式)。
576 是模板的非类型参数,通常表示某种尺寸或配置。
Flash_fwd_mla_params &params
这是一个参数结构体,通常包含多个成员变量,用以传递函数所需的计算参数。
它可能包含输入张量(如 Query、Key、Value)、输出张量、硬件配置(如线程块大小)以及其他计算相关的元数据。
cudaStream_t stream
这是一个 CUDA 流的标识符,用于指定在 GPU 上执行的内核函数的流。
CUDA 流允许开发者控制 GPU 内核的执行顺序和同步,不同的流中的操作可能会并行执行。
run_mha_fwd_splitkv_mla
这是一个多头注意力(MHA)的前向计算函数,具体实现可能包含以下操作:
计算注意力权重(使用 Query 和 Key)。
应用 Softmax 函数归一化权重。
将权重与 Value 相乘,得到注意力输出。*/
​
/*** 该代码片段将一个基于模板的多头注意力计算函数实例化为一个具体的 CUDA 函数,用于在 GPU 上执行。
它需要预定义的参数结构体(Flash_fwd_mla_params)和 CUDA 流(stream)作为输入。
通过显式实例化模板函数,确保编译器生成特定于 bfloat16 数据类型和参数 576 的优化代码。*/
// 显式实例化模板函数,将其绑定到特定的数据类型和参数
template void run_mha_fwd_splitkv_mla<cutlass::bfloat16_t, 576>(Flash_fwd_mla_params &params, // 传递参数,包括张量、硬件配置和其他计算信息cudaStream_t stream            // 指定 CUDA 流,用于控制内核的执行顺序和同步
);
flash_fwd_mla_kernel.h
#pragma once
​
#include <cute/tensor.hpp>
#include <cutlass/cutlash.h>
#include <cutlass/array.h>
#include <cutlass/numeric_types.h>
​
using namespace cute;
​
#include "named_barrier.h"
#include "utils.h"
#include "softmax.h"
#include "static_switch.h"
#include "flash_mla.h"
/*** 模板结构体:Flash_fwd_kernel_traits_mla 定义了多头注意力前向计算的模板参数和数据布局,包括线程块大小、头大小、数据类型等。张量化矩阵乘法:make_tiled_mma 用于定义张量化的矩阵乘法操作,支持不同布局和数据类型的矩阵乘法。共享内存布局:SmemLayoutQ 和 SmemLayoutK 定义了共享内存中查询和键张量的布局,确保数据对齐和高效访问。CUDA 内核函数:flash_fwd_splitkv_mla_kernel 和 flash_fwd_splitkv_mla_combine_kernel 分别负责多头注意力的前向计算和分块结果的合并。因果机制:Is_causal 参数用于启用或禁用因果机制,确保在计算时不会泄露未来的序列信息。Softmax 归一化:flash::Softmax 用于计算 Softmax 归一化,确保注意力权重的稳定性。分片合并:通过共享内存和全局内存存储累积结果,并最终将结果合并到输出张量中。* */
​
// 根据头大小选择共享内存(SMEM)的布局
template<typename PrecType, int DIM, int DIM2 = DIM>
constexpr auto getSmemLayoutK() {// 根据头大小对齐要求返回相应的 SMEM 布局constexpr int headSizeBytes = sizeof(PrecType) * DIM;constexpr int headSizeBytes2 = sizeof(PrecType) * DIM2;
​if constexpr (headSizeBytes % 128 == 0 && headSizeBytes2 % 128 == 0) {return GMMA::Layout_K_SW128_Atom<PrecType>{};  // 128 字节对齐} else if constexpr (headSizeBytes % 64 == 0 && headSizeBytes2 % 64 == 0) {return GMMA::Layout_K_SW64_Atom<PrecType>{};   // 64 字节对齐} else {return GMMA::Layout_K_SW32_Atom<PrecType>{};   // 32 字节对齐}
}
​
// 定义多头注意力(MHA)的前向计算内核的模板结构体
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, typename elem_type = cutlass::bfloat16_t, int kHeadDimV_ = 0>
struct Flash_fwd_kernel_traits_mla {using Element = elem_type;                 // 数据类型(如 bfloat16)using ElementAccum = float;                // 累加数据类型(如 float)using index_t = int64_t;                   // 索引类型
​static constexpr int kNWarps = kNWarps_;    // 每个线程块的 warp 数量static constexpr int kNThreads = kNWarps * 32;  // 每个线程块的线程数量static constexpr int kNWarpsS = 4;          // SMEM 累加部分的 warp 数量static constexpr int kNThreadsS = kNWarpsS * 32; // SMEM 累加部分的线程数量
​static constexpr int kBlockM = kBlockM_;    // 线程块大小(行)static constexpr int kBlockN = kBlockN_;    // 线程块大小(列)static constexpr int kHeadDim = kHeadDim_;  // 查询/键的头大小static_assert(kHeadDim % 32 == 0);          // 确保头大小是 32 的倍数
​static constexpr int kHeadDimV = kHeadDimV_ != 0 ? kHeadDimV_ : kHeadDim; // 值的头大小static_assert(kHeadDimV % 32 == 0);         // 确保值头大小是 32 的倍数static_assert(kHeadDimV <= kHeadDim);       // 值头大小不能大于查询/键的头大小static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32; // SMEM 的块大小static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3;      // 数据重排方式
​// 定义张量化矩阵乘法(Tiled MMA)操作using TiledMma = decltype(make_tiled_mma(cute::GMMA::ss_op_selector<Element, Element, ElementAccum, Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>,GMMA::Major::K, GMMA::Major::K>(),Layout<Shape<Int<kNWarpsS / 4>, _1, _1>>{}));
​// 定义累加器的部分static constexpr int AtomLayoutNO = kNThreads / kNThreadsS;using TiledMmaO = decltype(make_tiled_mma(cute::GMMA::rs_op_selector<Element, Element, ElementAccum, Shape<Int<kBlockM>, Int<kHeadDimV / AtomLayoutNO>, Int<kBlockN>>,GMMA::Major::K, GMMA::Major::MN>(),Layout<Shape<Int<kNWarpsS / 4>, Int<AtomLayoutNO>, _1>>{}));
​// 定义 SMEM 的布局using SmemLayoutQ = decltype(tile_to_shape(getSmemLayoutK<Element, kHeadDim>(), Shape<Int<kBlockM>, Int<kHeadDim>>{}));
​using SmemLayoutK = decltype(tile_to_shape(getSmemLayoutK<Element, kHeadDim, kHeadDimV>(), Shape<Int<kBlockN>, Int<kHeadDim>>{}));
​using SmemLayoutV = decltype(tile_to_shape(getSmemLayoutK<Element, kHeadDim, kHeadDimV>(), Shape<Int<kBlockN>, Int<kHeadDimV>>{}));
​using SmemLayoutVtransposed = decltype(composition(SmemLayoutV{}, make_layout(Shape<Int<kHeadDimV>, Int<kBlockN>>{}, GenRowMajor{})));
​using SmemLayoutP = Layout<Shape<Shape<_2, _2>, Int<kNThreadsS>, _1, Int<kBlockN / 8>>>;using SmemLayoutRow = Layout<Shape<_2, Int<kNThreadsS>>, Stride<_1, _2>>;
​// 定义 SMEM 累加数据的布局using SmemLayoutAtomO = decltype(composition(Swizzle<kSwizzle, 3, 3>{}, Layout<Shape<Int<8>, Int<kBlockKSmem>>, Stride<Int<kBlockKSmem>, _1>>{}));using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, Shape<Int<kBlockM>, Int<kHeadDimV>>{}));
​// 定义 SMEM 的拷贝操作using SmemCopyAtomO = Copy_Atom<SM90_U32x4_STSM_N, Element>;using SmemCopyAtomOaccum = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>;
​// 定义张量加载和存储的配置static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad");static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad;using Gmem_copy_struct = SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>;static constexpr int kNThreadsLoad = kNThreads - kNThreadsS;static_assert(kNThreadsLoad % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow");
​using GmemLayoutAtom = Layout<Shape<Int<kNThreadsLoad / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,Stride<Int<kGmemThreadsPerRow>, _1>>;using GmemTiledCopy = decltype(make_tiled_copy(Copy_Atom<Gmem_copy_struct, Element>{},GmemLayoutAtom{},Layout<Shape<_1, _8>>{}));  // 每次加载 8 个元素
​using GmemLayoutAtomO = Layout<Shape<Int<kNThreadsS / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,Stride<Int<kGmemThreadsPerRow>, _1>>;using GmemTiledCopyO = decltype(make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},GmemLayoutAtomO{},Layout<Shape<_1, _8>>{}));  // 每次存储 8 个元素
​static constexpr int kGmemElemsPerLoadAccum = sizeof(cute::uint128_t) / sizeof(ElementAccum);static constexpr int kGmemThreadsPerRowAccum = kBlockKSmem / kGmemElemsPerLoadAccum;using GmemLayoutAtomOaccum = Layout<Shape<Int<kNThreadsS / kGmemThreadsPerRowAccum>, Int<kGmemThreadsPerRowAccum>>,Stride<Int<kGmemThreadsPerRowAccum>, _1>>;using GmemTiledCopyOaccum = decltype(make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{},GmemLayoutAtomOaccum{},Layout<Shape<_1, _4>>{}));  // 每次存储 4 个累加元素
};
​
namespace flash {
​
using namespace cute;
​
// 定义共享内存(Shared Storage)的结构
template<typename Kernel_traits>
struct SharedStorageMLA {union {struct {// 查询、键和累积值的 SMEM 存储cute::array_aligned<typename Kernel_traits::Element, cute::cosize_v<typename Kernel_traits::SmemLayoutQ>> smem_q;cute::array_aligned<typename Kernel_traits::Element, cute::cosize_v<typename Kernel_traits::SmemLayoutK> * 2> smem_k;  // 双缓冲cute::array_aligned<typename Kernel_traits::Element, cute::cosize_v<typename Kernel_traits::SmemLayoutP>> smem_p;cute::array_aligned<typename Kernel_traits::ElementAccum, cute::cosize_v<typename Kernel_traits::SmemLayoutRow>> smem_scale;};struct {// 注意力权重的最大值、总和和累加结果的 SMEM 存储cute::array_aligned<typename Kernel_traits::ElementAccum, cute::cosize_v<typename Kernel_traits::SmemLayoutRow>> smem_max;cute::array_aligned<typename Kernel_traits::ElementAccum, cute::cosize_v<typename Kernel_traits::SmemLayoutRow>> smem_sum;cute::array_aligned<typename Kernel_traits::ElementAccum, cute::cosize_v<typename Kernel_traits::SmemLayoutO>> smem_o;};};
};
​​
// 定义在完成计算后存储结果的函数模板
template<typename Kernel_traits, bool Split, typename SharedStorage, typename AccO, typename Softmax>
__forceinline__ __device__ void store(const Flash_fwd_mla_params &params, const int bidb, const int bidh, const int m_block, const int n_split_idx,SharedStorage &shared_storage, AccO tOrO, Softmax softmax) {// 模板参数和数据类型的定义constexpr int kBlockM = Kernel_traits::kBlockM;constexpr int kHeadDimV = Kernel_traits::kHeadDimV;constexpr int kNThreadsS = Kernel_traits::kNThreadsS;using Element = typename Kernel_traits::Element;using ElementAccum = typename Kernel_traits::ElementAccum;using index_t = typename Kernel_traits::index_t;
​const int tidx = threadIdx.x;
​// 获取 Tiled MMA 操作的线程分片typename Kernel_traits::TiledMmaO tiled_mma_o;auto thr_mma_o = tiled_mma_o.get_thread_slice(tidx);
​// 进行 softmax 归一化Tensor lse = softmax.template normalize_softmax_lse<!Split, Split>(tOrO, params.scale_softmax);
​// 确定输出数据的数据类型using ElementO = std::conditional_t<!Split, Element, ElementAccum>;
​// 获取输出张量的 SMEM 布局Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast<ElementO *>(shared_storage.smem_o.data())), typename Kernel_traits::SmemLayoutO{});
​// 获取 SMEM 拷贝操作的线程分片using SmemTiledCopyO = std::conditional_t<!Split, typename Kernel_traits::SmemCopyAtomO, typename Kernel_traits::SmemCopyAtomOaccum>;auto smem_tiled_copy_Oaccum = make_tiled_copy_C(SmemTiledCopyO{}, tiled_mma_o);auto smem_thr_copy_Oaccum = smem_tiled_copy_Oaccum.get_thread_slice(tidx);
​// 对分片张量进行拷贝Tensor rO = flash::convert_type<ElementO>(tOrO);Tensor taccOrOaccum = smem_thr_copy_Oaccum.retile_S(rO);Tensor taccOsOaccum = smem_thr_copy_Oaccum.partition_D(sOaccum);
​__syncthreads();cute::copy(smem_tiled_copy_Oaccum, taccOrOaccum, taccOsOaccum);
​// 定义偏移量const index_t row_offset_o = bidb * params.o_batch_stride + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;const index_t row_offset_oaccum = (((Split ? params.num_splits_ptr : NULL) + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM;const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM;const index_t row_offset_lseaccum = (((Split ? params.num_splits_ptr : NULL) + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
​// 定义输出张量Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementO *>(Split ? params.oaccum_ptr : params.o_ptr) + row_offset_o), Shape<Int<kBlockM>, Int<kHeadDimV>>{}, make_stride(Split ? kHeadDimV : params.o_row_stride, _1{}));Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + row_offset_lse), Shape<Int<kBlockM>>{}, Stride<_1>{});
​// 获取 GMEM 拷贝操作的线程分片using GmemTiledCopyO = std::conditional_t<!Split, typename Kernel_traits::GmemTiledCopyO, typename Kernel_traits::GmemTiledCopyOaccum>;GmemTiledCopyO gmem_tiled_copy_Oaccum;auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx);Tensor tOsOaccum = gmem_thr_copy_Oaccum.partition_S(sOaccum);Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum);
​__syncthreads();
​if (tidx >= kNThreadsS) { return; }
​// 拷贝数据到输出张量Tensor tOrOaccum = make_tensor<ElementO>(shape(tOgOaccum));cute::copy(gmem_tiled_copy_Oaccum, tOsOaccum, tOrOaccum);
​// 计算输出张量的布局Tensor caccO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDimV>>{});Tensor taccOcO = thr_mma_o.partition_C(caccO);Tensor taccOcO_row = taccOcO(make_coord(0, _, 0), _, 0);if (get<1>(taccOcO_row(0)) == 0) {
#pragma unrollfor (int mi = 0; mi < size(taccOcO_row); ++mi) {const int row = get<0>(taccOcO_row(mi));if (row < params.seqlen_q - m_block * kBlockM) { gLSEaccum(row) = lse(mi); }}}
​// 将输出张量存储到全局内存Tensor cO = make_identity_tensor(make_shape(size<0>(sOaccum), size<1>(sOaccum)));Tensor tOcO = gmem_thr_copy_Oaccum.partition_D(cO);Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgOaccum)));
​flash::copy<!Split, /*Is_even_K=*/true, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, params.seqlen_q - m_block * kBlockM);
}
​
// 定义计算注意力权重的函数模板
template<typename Kernel_traits, bool Is_causal, typename SharedStorage>
__forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_fwd_mla_params &params,const int bidb, const int bidh, const int m_block,const int n_split_idx, const int seqlen_k,const int n_block_min, const int n_block_max, const bool NoSplit,SharedStorage &shared_storage) {// 模板参数和数据类型的定义constexpr int kBlockM = Kernel_traits::kBlockM;constexpr int kBlockN = Kernel_traits::kBlockN;constexpr int kHeadDim = Kernel_traits::kHeadDim;constexpr int kHeadDimV = Kernel_traits::kHeadDimV;constexpr int kNThreads = Kernel_traits::kNThreads;constexpr int kNThreadsS = Kernel_traits::kNThreadsS;static_assert(kNThreads == 256 && kNThreadsS == 128);using Element = typename Kernel_traits::Element;using index_t = typename Kernel_traits::index_t;
​const int tidx = threadIdx.x;int n_block = n_block_max - 1;
​// 定义 SMEM 张量的布局Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), typename Kernel_traits::SmemLayoutQ{});Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutK{});Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutV{});Tensor sVt = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutVtransposed{});
​Tensor sP = make_tensor(make_smem_ptr(shared_storage.smem_p.data()), typename Kernel_traits::SmemLayoutP{});Tensor tPsP = sP(_, tidx % kNThreadsS, _, _);Tensor sScale_o = make_tensor(make_smem_ptr(shared_storage.smem_scale.data()), typename Kernel_traits::SmemLayoutRow{});Tensor tScale_osScale_o = sScale_o(_, tidx % kNThreadsS);Tensor sRow_max = make_tensor(make_smem_ptr(shared_storage.smem_max.data()), typename Kernel_traits::SmemLayoutRow{});Tensor tRow_maxsRow_max = sRow_max(_, tidx % kNThreadsS);Tensor sRow_sum = make_tensor(make_smem_ptr(shared_storage.smem_sum.data()), typename Kernel_traits::SmemLayoutRow{});Tensor tRow_sumsRow_sum = sRow_sum(_, tidx % kNThreadsS);
​// 定义 Tiled MMA 操作typename Kernel_traits::TiledMmaO tiled_mma_o;auto thr_mma_o = tiled_mma_o.get_thread_slice(tidx);Tensor tOrVt = thr_mma_o.partition_fragment_B(sVt);Tensor tOrO = partition_fragment_C(tiled_mma_o, Shape<Int<kBlockM>, Int<kHeadDimV>>{});clear(tOrO);
​// 定义 Softmax 对象flash::Softmax<2 * size<1>(tOrO)> softmax;
​int warp_group_idx = cutlass::canonical_warp_group_idx();if (warp_group_idx == 0) {// 获取 Tiled MMA 操作的线程分片typename Kernel_traits::TiledMma tiled_mma;auto thr_mma = tiled_mma.get_thread_slice(tidx);Tensor tSrQ = thr_mma.partition_fragment_A(sQ);Tensor tSrK = thr_mma.partition_fragment_B(sK);
​if (n_block % 2 == 1) {// 对 SMEM 进行双缓冲constexpr int sK_offset = size(sK);tSrK.data() = tSrK.data() + sK_offset / 8;tOrVt.data() = tOrVt.data() + sK_offset / 8;}
​// 定义掩码步数constexpr int n_masking_steps = !Is_causal ? 1 : static_cast<int>(ceil_div(kBlockM, kBlockN)) + 1;
​// 循环处理掩码步for (int masking_step = n_masking_steps; n_block >= n_block_min; --masking_step, --n_block) {__syncthreads();
​// 执行矩阵乘法Tensor tSrS = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});flash::gemm<! true, /*wg_wait=*/0>(tiled_mma, tSrQ, tSrK, tSrS);
​bool is_masking_step = masking_step > 0;bool is_first_masking_step = masking_step == n_masking_steps;
​if (is_masking_step) {Tensor cS = make_identity_tensor(Shape<Int<kBlockM>, Int<kBlockN>>{});Tensor tScS = thr_mma.partition_C(cS);
​// 根据条件设置掩码if constexpr (!Is_causal) {for (int i = 0; i < size(tSrS); ++i) {if (int(get<1>(tScS(i))) >= int(seqlen_k - n_block * kBlockN)) {tSrS(i) = -INFINITY;}}} else {for (int i = 0; i < size(tSrS); ++i) {int row = int(get<0>(tScS(i)));int col_limit_right = seqlen_k - 1 - n_block * kBlockN - (params.seqlen_q - 1 - (m_block * kBlockM + row)) / params.ngroups;if (int(get<1>(tScS(i))) > col_limit_right) {tSrS(i) = -INFINITY;}}}}
​// 计算 Softmax 和缩放因子Tensor scale_o = is_first_masking_step? softmax.template softmax<! true, Is_causal>(tSrS, params.scale_softmax_log2): is_masking_step ? softmax.template softmax<! false, Is_causal>(tSrS, params.scale_softmax_log2): softmax.template softmax<! false, false>(tSrS, params.scale_softmax_log2);
​// 将结果存储到 SMEMTensor rP = flash::convert_type<Element>(tSrS);cute::copy(rP, tPsP);cute::copy(scale_o, tScale_osScale_o);
​cutlass::arch::NamedBarrier::arrive(kNThreads, static_cast<int>(NamedBarriers::SReady));
​// 更新输出张量flash::rescale_o(tOrO, scale_o);
​// 执行矩阵乘法,计算输出Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));flash::gemm<! false, /*wg_wait=*/0>(tiled_mma_o, tOrP, tOrVt, tOrO);
​// 更新 SMEM 的数据指针const int sK_offset = n_block % 2 == 0 ? size(sK) : -size(sK);tSrK.data() = tSrK.data() + sK_offset / 8;tOrVt.data() = tOrVt.data() + sK_offset / 8;}
​// 将 Softmax 的最大值和总和存储到 SMEMcute::copy(softmax.row_max, tRow_maxsRow_max);cute::copy(softmax.row_sum, tRow_sumsRow_sum);cutlass::arch::NamedBarrier::arrive(kNThreads, static_cast<int>(NamedBarriers::SoftmaxReady));} else {// 获取分块表的指针const int *block_table = params.block_table + bidb * params.block_table_batch_stride;int cur_block_table = __ldg(&block_table[n_block]);
​// 定义 GMEM 张量的布局const index_t row_offset_q = bidb * params.q_batch_stride + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride;Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q), Shape<Int<kBlockM>, Int<kHeadDim>>{}, make_stride(params.q_row_stride, _1{}));
​typename Kernel_traits::GmemTiledCopy gmem_tiled_copy_Q;auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_thread_slice(tidx - kNThreadsS);Tensor tQgQ = gmem_thr_copy_Q.partition_S(gQ);Tensor tQsQ = gmem_thr_copy_Q.partition_D(sQ);Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ)));Tensor tQcQ = gmem_thr_copy_Q.partition_S(cQ);Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tQsQ)));
​// 将查询数据加载到 SMEMflash::copy<! true, /*Is_even_K=*/true>(gmem_tiled_copy_Q, tQgQ, tQsQ, tQcQ, tQpQ, params.seqlen_q - m_block * kBlockM);
​// 定义 GMEM 张量的布局const index_t row_offset_k = (bidh / params.h_h_k_ratio) * params.k_head_stride;Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.k_ptr) + row_offset_k), Shape<Int<kBlockN>, Int<kHeadDim>>{}, make_stride(params.k_row_stride, _1{}));
​typename Kernel_traits::GmemTiledCopy gmem_tiled_copy_K;auto gmem_thr_copy_K = gmem_tiled_copy_K.get_thread_slice(tidx - kNThreadsS);Tensor tKgK = gmem_thr_copy_K.partition_S(gK);Tensor tKsK = gmem_thr_copy_K.partition_D(sK);Tensor cK = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK)));Tensor tKcK = gmem_thr_copy_K.partition_S(cK);Tensor tKpK = make_tensor<bool>(make_shape(size<2>(tKsK)));
​// 对 SMEM 进行双缓冲if (n_block % 2 == 1) {constexpr int sK_offset = size(sK);tKsK.data() = tKsK.data() + sK_offset;tOrVt.data() = tOrVt.data() + sK_offset / 8;}
​// 将键数据加载到 SMEMconst index_t offset_k = cur_block_table * params.k_batch_stride;tKgK.data() = tKgK.data() + offset_k;flash::copy<! true, /*Is_even_K=*/true, /*Clear_OOB_MN=*/true>(gmem_tiled_copy_K, tKgK, tKsK, tKcK, tKpK, seqlen_k - n_block * kBlockN);tKgK.data() = tKgK.data() - offset_k;cute::cp_async_fence();
​// 循环处理分块for (; n_block >= n_block_min; --n_block) {flash::cp_async_wait<0>();__syncthreads();
​if (n_block - 1 >= n_block_min) {// 对 SMEM 进行双缓冲const int sK_offset = n_block % 2 == 0 ? size(sK) : -size(sK);tKsK.data() = tKsK.data() + sK_offset;
​// 加载下一个分块的数据cur_block_table = __ldg(&block_table[n_block - 1]);const index_t offset_k = cur_block_table * params.k_batch_stride;tKgK.data() = tKgK.data() + offset_k;flash::copy<! true, /*Is_even_K=*/true>(gmem_tiled_copy_K, tKgK, tKsK, tKcK, tKpK);tKgK.data() = tKgK.data() - offset_k;cute::cp_async_fence();}
​// 同步线程cutlass::arch::NamedBarrier::sync(kNThreads, static_cast<int>(NamedBarriers::SReady));
​// 获取更新后的分块数据if (n_block - 2 >= n_block_min) {cur_block_table = __ldg(&block_table[n_block - 2]);}
​// 获取 Tiled MMA 操作的线程分片typename Kernel_traits::TiledMma tiled_mma;auto tSrS_layout = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});Tensor rP = make_tensor<Element>(tSrS_layout.layout());Tensor scale_o = make_tensor<float>(Shape<_2>{});cute::copy(tScale_osScale_o, scale_o);cute::copy(tPsP, rP);
​// 更新输出张量flash::rescale_o(tOrO, scale_o);Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));flash::gemm<! false, /*wg_wait=*/0>(tiled_mma_o, tOrP, tOrVt, tOrO);
​// 更新 SMEM 的数据指针const int sK_offset = n_block % 2 == 0 ? size(sK) : -size(sK);tOrVt.data() = tOrVt.data() + sK_offset / 8;}
​// 同步线程cutlass::arch::NamedBarrier::sync(kNThreads, static_cast<int>(NamedBarriers::SoftmaxReady));cute::copy(tRow_maxsRow_max, softmax.row_max);cute::copy(tRow_sumsRow_sum, softmax.row_sum);}
​// 根据条件存储结果if (NoSplit) {store<Kernel_traits, false>(params, bidb, bidh, m_block, n_split_idx, shared_storage, tOrO, softmax);} else {store<Kernel_traits, true>(params, bidb, bidh, m_block, n_split_idx, shared_storage, tOrO, softmax);}
}
​
// 定义 CUDA 内核函数模板,负责多头注意力的前向计算
template<typename Kernel_traits, bool Is_causal, typename SharedStorage>
__global__ void __launch_bounds__(Kernel_traits::kNThreads, 1, 1)
flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params) {// 模板参数的定义constexpr int kBlockN = Kernel_traits::kBlockN;const int m_block = blockIdx.x;const int bidh = blockIdx.y;const int partition_idx = blockIdx.z;
​extern __shared__ char shared_memory[];auto &shared_storage = *reinterpret_cast<SharedStorage *>(shared_memory);
​// 获取调度元数据int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr + partition_idx * TileSchedulerMetaDataSize;int4 tile_scheduler_metadata = __ldg(reinterpret_cast<int4 *>(tile_scheduler_metadata_ptr));int begin_idx = tile_scheduler_metadata.x;int begin_seqlen = tile_scheduler_metadata.y;int end_idx = tile_scheduler_metadata.z;int end_seqlen = tile_scheduler_metadata.w;if (begin_idx >= params.b) return;int begin_n_split_idx = __ldg(tile_scheduler_metadata_ptr + 4);
​// 循环处理批量数据for (int batch_id = begin_idx; batch_id <= end_idx; ++batch_id) {const int n_split_idx = batch_id == begin_idx ? begin_n_split_idx : 0;const int seqlen_k = __ldg(params.cu_seqlens_k + batch_id);const int n_block_min = batch_id == begin_idx ? begin_seqlen / kBlockN : 0;const int n_block_max = batch_id == end_idx ? cute::ceil_div(end_seqlen, kBlockN) : cute::ceil_div(seqlen_k, kBlockN);const bool NoSplit = n_block_min == 0 && n_block_max == cute::ceil_div(seqlen_k, kBlockN);
​if (batch_id > begin_idx) {__syncthreads(); // 同步线程}
​// 调用计算注意力权重的函数模板flash::compute_attn_1rowblock_splitkv_mla<Kernel_traits, Is_causal>(params, batch_id, bidh, m_block, n_split_idx, seqlen_k, n_block_min, n_block_max, NoSplit, shared_storage);}
}
​
// 定义将分块结果合并的 CUDA 内核函数模板
template<typename Element, typename ElementAccum, typename index_t, int kHeadDimV, int kMaxSplits>
__global__ void __launch_bounds__(256, 1, 1)
flash_fwd_splitkv_mla_combine_kernel(__grid_constant__ const Flash_fwd_mla_params params) {// 线程和块的配置constexpr int kNThreads = 128;const int tidx = threadIdx.x;const int bidx = blockIdx.x;const int hs = params.h * params.seqlen_q;const int batch_idx = bidx / hs;const int hs_idx = bidx % hs;
​// 获取分块信息const int split_offset = __ldg(params.num_splits_ptr + batch_idx);const int actual_num_splits = __ldg(params.num_splits_ptr + batch_idx + 1) - split_offset;FLASH_DEVICE_ASSERT(actual_num_splits <= kMaxSplits);if (actual_num_splits == 1) return;
​// 定义共享内存数组__shared__ ElementAccum sLseScale[kMaxSplits];
​// 定义全局存储的 LSE 累积张量const index_t row_offset_lseaccum = split_offset * hs + hs_idx;Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lseaccum_ptr) + row_offset_lseaccum), Shape<Int<kMaxSplits>>{}, make_stride(hs));
​// 定义全局存储的 LSE 张量Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + bidx), Shape<_1>{}, Stride<_1>{});
​int warp_idx = cutlass::canonical_warp_idx_sync();if (warp_idx == 0) {// 每个线程加载多个 LSE 值constexpr int kNLsePerThread = cute::ceil_div(kMaxSplits, 32);float local_lse[kNLsePerThread];
​for (int i = 0; i < kNLsePerThread; ++i) {const int split = i * 32 + tidx;local_lse[i] = split < actual_num_splits ? gLSEaccum(split) : -INFINITY;}
​// 在线程内计算最大值float max_lse = -INFINITY;for (int i = 0; i < kNLsePerThread; ++i) {max_lse = max(max_lse, local_lse[i]);}
​// 在 warp 内同步计算全局最大值for (int offset = 16; offset >= 1; offset /= 2) {max_lse = max(max_lse, __shfl_xor_sync(uint32_t(-1), max_lse, offset));}
​// 处理无效情况max_lse = max_lse == -INFINITY ? 0.0f : max_lse;
​// 计算总和float sum_lse = 0;for (int i = 0; i < kNLsePerThread; ++i) {sum_lse += expf(local_lse[i] - max_lse);}
​// 在 warp 内同步计算全局总和for (int offset = 16; offset >= 1; offset /= 2) {sum_lse += __shfl_xor_sync(uint32_t(-1), sum_lse, offset);}
​// 计算全局 LSEfloat global_lse = sum_lse == 0.f || isnan(sum_lse) ? INFINITY : logf(sum_lse) + max_lse;if (tidx == 0) {gLSE(0) = global_lse;}
​// 将缩放因子存储到共享内存for (int i = 0; i < kNLsePerThread; ++i) {const int split = i * 32 + tidx;if (split < actual_num_splits) {sLseScale[split] = expf(local_lse[i] - global_lse);}}}__syncthreads();
​// 计算输出张量的布局static_assert(kHeadDimV % kNThreads == 0);constexpr int Elements = kHeadDimV / kNThreads;const index_t row_offset_oaccum = (split_offset * hs + hs_idx) * kHeadDimV;
​// 定义 GMEM 张量的布局Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.oaccum_ptr) + row_offset_oaccum), Shape<Int<kHeadDimV>>{}, Stride<_1>{});
​// 定义 GMEM 拷贝操作的线程分片using GmemTiledCopyOaccum = decltype(make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{},Layout<Shape<Int<kNThreads>>>{},Layout<Shape<Int<Elements>>>{}));GmemTiledCopyOaccum gmem_tiled_copy_Oaccum;auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx);
​// 定义张量Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_S(gOaccum);Tensor tOrOaccum = make_tensor<ElementAccum>(shape(tOgOaccum));Tensor tOrO = make_tensor<ElementAccum>(shape(tOgOaccum));clear(tOrO);
​// 循环处理每个分块for (int split = 0; split < actual_num_splits; ++split) {cute::copy(tOgOaccum, tOrOaccum);ElementAccum lse_scale = sLseScale[split];for (int i = 0; i < size(tOrO); ++i) {tOrO(i) += lse_scale * tOrOaccum(i);}tOgOaccum.data() = tOgOaccum.data() + hs * kHeadDimV;}
​// 将结果存储到全局内存Tensor rO = flash::convert_type<Element>(tOrO);const int head_idx = (bidx - batch_idx * hs) / params.seqlen_q;const int row = bidx - batch_idx * hs - head_idx * params.seqlen_q;auto o_ptr = reinterpret_cast<Element *>(params.o_ptr) + batch_idx * params.o_batch_stride + head_idx * params.o_head_stride + row * params.o_row_stride;Tensor gO = make_tensor(make_gmem_ptr(o_ptr + tidx * Elements), Shape<Int<Elements>>{}, make_stride(_1()));cute::copy(rO, gO);
}
​
} // namespace flash
​
// 定义运行 Flash 分块前向计算的功能函数
template<typename Kernel_traits, typename SharedStorage>
void run_flash_splitkv_fwd_mla(Flash_fwd_mla_params &params, cudaStream_t stream) {// 参数校验FLASH_ASSERT(params.page_block_size == Kernel_traits::kBlockN);
​// 根据因果机制选择内核函数BOOL_SWITCH(params.is_causal, Is_causal, [&] {// 获取 CUDA 内核函数auto kernel = &flash::flash_fwd_splitkv_mla_kernel<Kernel_traits, Is_causal, SharedStorage>;
​// 设置内核函数的共享内存大小constexpr size_t smem_size = sizeof(SharedStorage);CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
​// 调用 CUDA 内核函数kernel<<<dim3(ceil_div(params.seqlen_q, Kernel_traits::kBlockM), params.h, params.num_sm_parts),Kernel_traits::kNThreads, smem_size, stream>>>(params);});
​// 检查 CUDA 内核函数的执行结果CHECK_CUDA_KERNEL_LAUNCH();
​// 调用 CUDA 内核函数,合并分块结果dim3 grid_combine(params.b * params.h * params.seqlen_q);MLA_NUM_SPLITS_SWITCH(params.num_sm_parts, kMaxSplits, [&] {auto combine_kernel = &flash::flash_fwd_splitkv_mla_combine_kernel<typename Kernel_traits::Element, typename Kernel_traits::ElementAccum, typename Kernel_traits::index_t, Kernel_traits::kHeadDimV, kMaxSplits>;combine_kernel<<<grid_combine, 128, 0, stream>>>(params);});
​// 检查 CUDA 内核函数的执行结果CHECK_CUDA_KERNEL_LAUNCH();
}
​
// 定义运行 MHA 的分块前向计算的功能函数
template<typename T, int Headdim>
void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params &params, cudaStream_t stream) {static_assert(Headdim == 576); // 确保头大小为 576FLASH_ASSERT(params.d_v == 512); // 确保值头大小为 512FLASH_ASSERT(params.k_ptr == params.v_ptr); // 确保键和值共享同一张量
​using Kernel_traits = Flash_fwd_kernel_traits_mla<576, 64, 64, 8, T, 512>; // 定义模板结构体run_flash_splitkv_fwd_mla<Kernel_traits, flash::SharedStorageMLA<Kernel_traits>>(params, stream);
}
​
// 定义生成 MLA 调度元数据的功能函数
static constexpr int MaxBatchSize = 4096;
​
__global__ void __launch_bounds__(256, 1, 1)
get_mla_metadata_kernel(__grid_constant__ const Mla_metadata_params params) {int *seqlens_k_ptr = params.seqlens_k_ptr;int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr;int *num_splits_ptr = params.num_splits_ptr;int batch_size = params.batch_size;int block_size_n = params.block_size_n;int fixed_overhead_num_blocks = params.fixed_overhead_num_blocks;int num_sm_parts = params.num_sm_parts;
​// 定义共享内存数组__shared__ int num_blocks_shared[MaxBatchSize];__shared__ int num_splits_shared[MaxBatchSize];
​// 计算总数int total_num_blocks = 0;for (int i = threadIdx.x; i < batch_size; i += 32) {int num_blocks = cutlass::ceil_div(seqlens_k_ptr[i], block_size_n);total_num_blocks += num_blocks + fixed_overhead_num_blocks;num_blocks_shared[i] = num_blocks;}
​// 使用同步指令进行线程间的数据聚合for (int offset = 16; offset >= 1; offset /= 2) {total_num_blocks += __shfl_xor_sync(uint32_t(-1), total_num_blocks, offset);}__syncwarp();
​// 计算调度元数据if (threadIdx.x == 0) {int payload = cutlass::ceil_div(total_num_blocks, num_sm_parts) + fixed_overhead_num_blocks;
​int now_idx = 0, now_block = 0, now_n_split_idx = 0, cum_num_splits = 0;num_splits_shared[0] = 0;
​for (int i = 0; i < num_sm_parts; ++i) {int tile_scheduler_metadata0[4], tile_scheduler_metadata1;
​tile_scheduler_metadata0[0] = now_idx;tile_scheduler_metadata0[1] = now_block * block_size_n;tile_scheduler_metadata1 = now_n_split_idx;
​int remain_payload = payload;
​while (now_idx < batch_size) {int num_blocks = num_blocks_shared[now_idx];int now_remain_blocks = num_blocks - now_block;
​if (remain_payload >= now_remain_blocks + fixed_overhead_num_blocks) {cum_num_splits += now_n_split_idx + 1;num_splits_shared[now_idx + 1] = cum_num_splits;remain_payload -= now_remain_blocks + fixed_overhead_num_blocks;++now_idx;now_block = 0;now_n_split_idx = 0;} else {if (remain_payload - fixed_overhead_num_blocks > 0) {now_block += remain_payload - fixed_overhead_num_blocks;++now_n_split_idx;remain_payload = 0;}break;}}
​tile_scheduler_metadata0[2] = now_block > 0 ? now_idx : now_idx - 1;tile_scheduler_metadata0[3] = now_block > 0 ? now_block * block_size_n : seqlens_k_ptr[now_idx - 1];*reinterpret_cast<int4 *>(tile_scheduler_metadata_ptr + i * TileSchedulerMetaDataSize) = *reinterpret_cast<int4 *>(tile_scheduler_metadata0);tile_scheduler_metadata_ptr[i * TileSchedulerMetaDataSize + 4] = tile_scheduler_metadata1;
​FLASH_DEVICE_ASSERT(now_idx == batch_size && now_block == 0 && now_n_split_idx == 0);}}__syncwarp();
​// 将分块数量存储到全局内存for (int i = threadIdx.x; i <= batch_size; i += 32) {num_splits_ptr[i] = num_splits_shared[i];}
}
​
// 定义生成 MLA 调度元数据的功能函数
void get_mla_metadata_func(Mla_metadata_params &params, cudaStream_t stream) {FLASH_ASSERT(params.batch_size < MaxBatchSize); // 确保批量大小在最大值范围内
​// 调用 CUDA 内核函数get_mla_metadata_kernel<<<1, 32, 0, stream>>>(params);CHECK_CUDA_KERNEL_LAUNCH();
}
flash_mla.h
#pragma once/*** 参数结构体:Flash_fwd_mla_params 和 Mla_metadata_params 是通过指针和索引传递参数的关键结构体,用于在 GPU 上进行高效的内存访问和计算调度。模板函数:run_mha_fwd_splitkv_mla 是一个模板函数,允许针对不同的数据类型和头大小进行优化。调度元数据:TileSchedulerMetaDataSize 定义了调度元数据的大小,用于控制 GPU 的计算调度。*//**  主要功能: 在 GPU 上实现高效的多头注意力计算* Flash_fwd_mla_params 结构体定义了多头注意力(MHA)前向计算所需的参数,包括:输入张量(查询、键、值)和输出张量的指针。张量的维度信息(如批量大小、序列长度、头大小)。
计算相关的超参数(如 Softmax 缩放因子、因果机制标志)。调度和分块信息(如分块表、元数据指针)。
run_mha_fwd_splitkv_mla 函数模板调用多头注意力的前向计算功能,支持分块和缓存优化。通过模板参数 T 和 Headdim,可以针对不同的数据类型和头大小进行优化。
Mla_metadata_params 结构体定义了生成多线程计算元数据所需的参数,包括:序列长度数组分块表和分块数指针。SM 的划分数量和固定开销的块数量。
get_mla_metadata_func 函数生成多线程计算的元数据,用于优化 GPU 上的调度和计算。*/
// 定义多头注意力(MHA)前向计算的参数结构体
struct Flash_fwd_mla_params {using index_t = int64_t; // 索引类型// 参数描述:// b: 批量大小// seqlen_q: 查询序列长度// d: 查询/键的头大小// d_v: 值的头大小// h: 注意力头数量// h_h_k_ratio: 注意力头的比例// ngroups: 分组数量// is_causal: 是否启用因果机制// scale_softmax: Softmax 缩放因子// scale_softmax_log2: Softmax 缩放因子(以 2 为底的对数形式)// cu_seqlens_k: 累积序列长度数组(指向 GPU 内存)int b, seqlen_q, d, d_v;int h, h_h_k_ratio, ngroups;bool is_causal;float scale_softmax, scale_softmax_log2;int *__restrict__ cu_seqlens_k;// 张量指针(指向 GPU 内存)void *__restrict__ q_ptr;       // 查询张量void *__restrict__ k_ptr;       // 键张量void *__restrict__ v_ptr;       // 值张量void *__restrict__ o_ptr;       // 输出张量void *__restrict__ softmax_lse_ptr; // Softmax 最终值张量// 张量的 Stride(布局)信息index_t q_batch_stride;         // 查询张量的批量 Strideindex_t k_batch_stride;         // 键张量的批量 Strideindex_t v_batch_stride;         // 值张量的批量 Strideindex_t o_batch_stride;         // 输出张量的批量 Strideindex_t q_row_stride;           // 查询张量的行 Strideindex_t k_row_stride;           // 键张量的行 Strideindex_t v_row_stride;           // 值张量的行 Strideindex_t o_row_stride;           // 输出张量的行 Strideindex_t q_head_stride;          // 查询张量的头 Strideindex_t k_head_stride;          // 键张量的头 Strideindex_t v_head_stride;          // 值张量的头 Strideindex_t o_head_stride;          // 输出张量的头 Stride// 分块表相关参数int *__restrict__ block_table;      // 分块表指针index_t block_table_batch_stride;   // 分块表的批量 Strideint page_block_size;                // 分块大小// 调度元数据相关参数int *__restrict__ tile_scheduler_metadata_ptr; // 调度元数据指针int num_sm_parts;                              // SM 的划分数量int *__restrict__ num_splits_ptr;              // 分块数指针// 累加值相关指针void *__restrict__ softmax_lseaccum_ptr;       // Softmax 最终值累加指针void *__restrict__ oaccum_ptr;                 // 输出累加指针
};// 调度元数据的大小(单位:字节)
static constexpr int TileSchedulerMetaDataSize = 8;
// 元数据的格式为:[begin_idx, begin_seqlen, end_idx, end_seqlen, begin_n_split_idx, _, _, _]// 定义运行多头注意力(MHA)的前向计算功能函数模板
template<typename T, int Headdim>
void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params &params, cudaStream_t stream);// 定义生成多线程计算元数据的结构体
struct Mla_metadata_params {// 序列长度数组(指向 GPU 内存)int *__restrict__ seqlens_k_ptr;// 调度元数据指针int *__restrict__ tile_scheduler_metadata_ptr;// 分块数指针int *__restrict__ num_splits_ptr;// 批量大小int batch_size;// 块大小(N 方向)int block_size_n;// 固定开销的块数量int fixed_overhead_num_blocks;// SM 的划分数量int num_sm_parts;
};// 定义生成多线程计算元数据的功能函数
void get_mla_metadata_func(Mla_metadata_params &params, cudaStream_t stream);named_barrier.h
#pragma once#include "cutlass/barrier.h" // 引入 Cutlass 库中的屏障(Barrier)功能
/*** 主要功能:多线程计算中提供一种同步机制,确保不同阶段的计算按顺序执行,避免数据竞争和冲突。*//*** cutlass/barrier.h:引入 Cutlass 库中的屏障功能,用于多线程同步。命名屏障:通过枚举类定义命名屏障,便于在代码中明确标识不同阶段的同步点。*//*** 命名屏障(NamedBarriers)定义了两个命名屏障,用于在多线程计算中同步不同阶段的执行。SReady:表示某个阶段(如数据加载或计算)已准备就绪。SoftmaxReady:表示 Softmax 计算已准备就绪。避免冲突通过为屏障命名,避免了不同计算阶段之间的潜在冲突,确保多线程计算的正确性和高效性。*/
namespace flash {// 定义命名屏障的枚举类,用于避免潜在的冲突
enum class NamedBarriers {SReady = 1,          // 表示某个阶段已准备就绪的屏障SoftmaxReady = 2,    // 表示 Softmax 计算已准备就绪的屏障
};} // namespace flash
softmax.h
#pragma once#include <cmath>#include <cute/tensor.hpp>
#include <cutlass/numeric_types.h>#include "utils.h"/*** 主要功能:为 Flash-Attention 提供高效的计算和归一化操作,针对大规模模型训练进行了优化。*//*** 线程内/多线程 reduce 操作:thread_reduce_ 和 quad_allreduce_ 分别实现了线程内和多线程间的 reduce 操作,用于将张量按行聚合为一维摘要张量。reduce_ 综合了线程内 reduce 和多线程 quad 全规约操作,用于高效的 GPU 并行计算。
张量变换与归一化:scale_apply_exp2 定义了对张量应用指数量化,通过指数量变换实现精确的归一化。max_scale_exp2_sum 综合了求最大值、指数变换和求和操作,用于计算 Softmax 归一化的中间结果。
Softmax 归一化:rescale_o 定义了对输出张量进行归一化的函数,确保归一化因子被正确应用。Softmax 模板类实现了对输入张量进行 Softmax 归一化的功能,包括求最大值、指数变换、求和和归一化因子计算等步骤。
误差与优化:代码中对 NaN 和无效值进行了处理,避免了计算过程中的错误。使用了 Allreduce 和 UNFUSE_FMA 等优化策略,提高了计算性*/namespace flash {using namespace cute;// 定义线程内 reduce 操作,将张量按行聚合为一维摘要张量
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
__device__ __forceinline__ void thread_reduce_(Tensor<Engine0, Layout0> const &tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {static_assert(Layout0::rank == 2, "Only support 2D Tensor"); // 确保输入张量是二维的static_assert(Layout1::rank == 1, "Only support 1D Tensor"); // 确保摘要张量是一维的CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor));  // 确保摘要张量的行数与输入张量的行数一致// 按行对张量进行聚合for (int mi = 0; mi < size<0>(tensor); mi++) {summary(mi) = zero_init ? tensor(mi, 0) : op(summary(mi), tensor(mi, 0));for (int ni = 1; ni < size<1>(tensor); ni++) {summary(mi) = op(summary(mi), tensor(mi, ni));}}
}// 定义多线程间的 quad 全规约操作,确保不同线程中的数据一致
template<typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
__device__ __forceinline__ void quad_allreduce_(Tensor<Engine0, Layout0> &dst, Tensor<Engine1, Layout1> &src, Operator &op) {CUTE_STATIC_ASSERT_V(size(dst) == size(src));  // 确保输入和输出张量的大小一致for (int i = 0; i < size(dst); i++){dst(i) = Allreduce<4>::run(src(i), op);     // 使用 Allreduce 操作进行数据规约}
}// 定义综合 reduce 操作,结合线程内 reduce 和多线程 quad 全规约
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
__device__ __forceinline__ void reduce_(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {thread_reduce_<zero_init>(tensor, summary, op); // 线程内 reduce 操作quad_allreduce_(summary, summary, op);          // 线程间 quad 全规约操作
}// 定义求最大值的 reduce 操作
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__device__ __forceinline__ void reduce_max(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &max){MaxOp<float> max_op;                            // 定义最大值操作reduce_<zero_init>(tensor, max, max_op);        // 使用 reduce_ 操作求最大值
}// 定义求和的 reduce 操作
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__device__ __forceinline__ void reduce_sum(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &sum){SumOp<float> sum_op;                            // 定义求和操作thread_reduce_<zero_init>(tensor, sum, sum_op); // 使用线程内 reduce 操作求和
}// 定义对张量应用指数量化
template <bool Scale_max=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__forceinline__ __device__ auto scale_apply_exp2(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &max, const float scale) {static_assert(Layout0::rank == 2, "Only support 2D Tensor"); // 确保输入张量是二维的static_assert(Layout1::rank == 1, "Only support 1D Tensor"); // 确保最大值张量是一维的CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));       // 确保最大值张量的行数与输入张量的行数一致// 按行对张量进行量级调整和指数变换for (int mi = 0; mi < size<0>(tensor); ++mi) {const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * (Scale_max ? scale : float(M_LOG2E));for (int ni = 0; ni < size<1>(tensor); ++ni)  {#ifdef UNFUSE_FMAtensor(mi, ni) = exp2f(__fmul_rn(tensor(mi, ni), scale) - max_scaled);#elsetensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);#endif}}return tensor;
}// 定义综合 max-scale-exp-sum 操作,包括求最大值、指数变换和求和
template <bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__forceinline__ __device__ void max_scale_exp2_sum(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> &max, Tensor<Engine1, Layout1> &sum, const float scale) {static_assert(Layout0::rank == 2, "Only support 2D Tensor"); // 确保输入张量是二维的static_assert(Layout1::rank == 1, "Only support 1D Tensor"); // 确保最大值和求和张量是一维的CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));       // 确保最大值张量的行数与输入张量的行数一致// 按行对张量进行 max-scale-exp-sum 操作for (int mi = 0; mi < size<0>(tensor); ++mi) {MaxOp<float> max_op;                            // 定义最大值操作max(mi) = zero_init ? tensor(mi, 0) : max_op(max(mi), tensor(mi, 0));for (int ni = 1; ni < size<1>(tensor); ni++) {max(mi) = max_op(max(mi), tensor(mi, ni));}max(mi) = Allreduce<4>::run(max(mi), max_op);   // 使用 Allreduce 操作进行数据规约const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * scale;sum(mi) = 0;for (int ni = 0; ni < size<1>(tensor); ++ni)  {tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled); // 应用指数变换sum(mi) += tensor(mi, ni);                                  // 求和}SumOp<float> sum_op;                                            // 定义求和操作sum(mi) = Allreduce<4>::run(sum(mi), sum_op);                   // 使用 Allreduce 操作进行数据规约}
}// 定义对输出张量进行归一化的函数
template<typename Tensor0, typename Tensor1>
__forceinline__ __device__ void rescale_o(Tensor0 &acc_o, Tensor1 &scale_o) {// 转换 acc_o 的布局为行优先Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));for (int mi = 0; mi < size(scale_o); ++mi) {for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale_o(mi); // 应用归一化因子}}
}// 定义用于计算 Softmax 的模板类
template <int kNRows>
struct Softmax {using TensorT = decltype(make_tensor<float>(Shape<Int<kNRows>>{})); // 定义摘要张量类型TensorT row_max, row_sum;                                           // 行最大值和行和__forceinline__ __device__ Softmax() {}                             // 默认构造函数// 定义对输入张量进行 Softmax 操作并返回缩放因子template<bool Is_first, bool Check_inf=false, typename Tensor0>__forceinline__ __device__ TensorT softmax(Tensor0 &acc_s, float softmax_scale_log2) {// 转换张量布局为行优先Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));static_assert(decltype(size<0>(scores))::value == kNRows);      // 确保输入张量符合布局要求TensorT scale_o;                                                // 初始化缩放因子clear(scale_o);if (Is_first) {// 如果是第一次迭代,计算行最大值flash::template reduce_max</*zero_init=*/true>(scores, row_max);// 应用指数变换flash::scale_apply_exp2<scales_max>::apply(scores, row_max, softmax_scale_log2);// 计算行和flash::reduce_sum</*zero_init=*/true>(scores, row_sum);} else {// 否则,使用缓存的行最大值继续计算Tensor scores_max_prev = make_fragment_like(row_max);cute::copy(row_max, scores_max_prev);flash::template reduce_max</*zero_init=*/false>(scores, row_max);#pragma unrollfor (int mi = 0; mi < size(row_max); ++mi) {float scores_max_cur = !Check_inf ? row_max(mi) : (row_max(mi) == -INFINITY ? 0.0f : row_max(mi));float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);scale_o(mi) = scores_scale;row_sum(mi) *= scores_scale;}flash::scale_apply_exp2<scales_max>::apply(scores, row_max, softmax_scale_log2);flash::reduce_sum</*zero_init=*/false>(scores, row_sum);}return scale_o;}// 定义对输入张量进行 Softmax 归一化并返回最终值template<bool Is_dropout=false, bool Split=false, typename Tensor0>__forceinline__ __device__ TensorT normalize_softmax_lse(Tensor0 &acc_o, float softmax_scale, float rp_dropout=1.0) {SumOp<float> sum_op;                                            // 定义求和操作quad_allreduce_(row_sum, row_sum, sum_op);                      // 使用多线程 Quad Allreduce 操作进行数据规约TensorT lse = make_fragment_like(row_sum);                      // 初始化累加最终值Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows); // 确保输入张量符合布局要求// 根据行和计算累加最终值for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) {float sum = row_sum(mi);float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : row_max(mi) * softmax_scale + __logf(sum);float scale = !Is_dropout ? inv_sum : inv_sum * rp_dropout;for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; // 应用归一化因子}}return lse;}
};}  // namespace flash
static_switch.h
#pragma once
/*** 主要功能:为 Flash-Attention 提供基本的工具宏,用于错误检查、断言和条件编译。*//*** 错误检查CHECK_CUDA 和 CHECK_CUDA_KERNEL_LAUNCH:这两个宏用于检查 CUDA 调用是否成功,确保 GPU 操作的安全性和正确性。CHECK_CUDA 用于显式检查 CUDA 函数调用的结果,而 CHECK_CUDA_KERNEL_LAUNCH 则在 CUDA 内核启动后检查是否有错误发生。
断言FLASH_ASSERT 和 FLASH_DEVICE_ASSERT:这两个宏用于在代码中插入断言,确保程序状态符合预期。FLASH_ASSERT 用于主机代码,FLASH_DEVICE_ASSERT 用于 GPU 设备代码,当断言失败时,程序会终止并打印错误信息。
条件编译BOOL_SWITCH 和 MLA_NUM_SPLITS_SWITCH:这两个宏用于在编译时或运行时根据条件动态选择代码路径。BOOL_SWITCH 根据布尔条件选择不同的实现,而 MLA_NUM_SPLITS_SWITCH 根据分块数量选择不同的实现,从而优化 GPU 的计算性能。*/
// 定义检查 CUDA 调用是否成功的宏
#define CHECK_CUDA(call)                                                                                  \do {                                                                                                  \cudaError_t status_ = call;     // 调用 CUDA 函数或操作if (status_ != cudaSuccess) {   // 如果调用失败fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); // 打印错误信息exit(1);                      // 退出程序}                                                                                                 \} while(0) // 确保宏的行为类似于一条语句// 定义检查 CUDA 内核启动是否成功的宏
#define CHECK_CUDA_KERNEL_LAUNCH() CHECK_CUDA(cudaGetLastError()) // 检查上一次 CUDA 内核调用是否有错误// 定义运行时断言宏,用于在 CPU 代码中检查条件是否成立
#define FLASH_ASSERT(cond)                                                                                \do {                                                                                                  \if (not (cond)) {                                                                                 \fprintf(stderr, "Assertion failed (%s:%d): %s\n", __FILE__, __LINE__, #cond);                 // 打印错误信息exit(1);                                                                                      // 退出程序}                                                                                                 \} while(0)// 定义设备端断言宏,用于在 GPU 代码中检查条件是否成立
#define FLASH_DEVICE_ASSERT(cond)                                                                         \do {                                                                                                  \if (not (cond)) {                                                                                 \printf("Assertion failed (%s:%d): %s\n", __FILE__, __LINE__, #cond);                          // 打印错误信息asm("trap;");                                                                                 // 触发陷阱指令,终止程序}                                                                                                 \} while(0)// 定义布尔条件编译宏,用于根据布尔条件选择不同的代码路径
#define BOOL_SWITCH(COND, CONST_NAME, ...)      \[&] {                                         \if (COND) {                                 \constexpr static bool CONST_NAME = true;  // 如果条件成立,设置 CONST_NAME 为 truereturn __VA_ARGS__();                     // 执行相应的代码} else {                                    \constexpr static bool CONST_NAME = false; // 如果条件不成立,设置 CONST_NAME 为 falsereturn __VA_ARGS__();                     // 执行相应的代码}                                           \}()// 定义根据分块数量选择不同实现的宏,用于动态选择最优的计算策略
#define MLA_NUM_SPLITS_SWITCH(NUM_SPLITS, NAME, ...) \[&] {                                              \if (NUM_SPLITS <= 32) {                          \constexpr static int NAME = 32;                // 分块数量小于等于 32 时,选择对应实现return __VA_ARGS__();                          \} else if (NUM_SPLITS <= 64) {                   \constexpr static int NAME = 64;                // 分块数量在 33-64 之间时,选择对应实现return __VA_ARGS__();                          \} else if (NUM_SPLITS <= 96) {                   \constexpr static int NAME = 96;                // 分块数量在 65-96 之间时,选择对应实现return __VA_ARGS__();                          \} else if (NUM_SPLITS <= 128) {                  \constexpr static int NAME = 128;               // 分块数量在 97-128 之间时,选择对应实现return __VA_ARGS__();                          \} else if (NUM_SPLITS <= 160) {                  \constexpr static int NAME = 160;               // 分块数量在 129-160 之间时,选择对应实现return __VA_ARGS__();                          \} else {                                         \FLASH_ASSERT(false);                           // 分块数量超过 160 时,触发断言错误}                                                \}()
utils.h
// 从 https://github.com/Dao-AILab/flash-attention/blob/main/hopper/utils.h 适配而来#pragma once#include <assert.h>
#include <stdint.h>
#include <stdlib.h>#include <cuda_bf16.h>  // 包含 NVIDIA CUDA 的半精度浮点数支持#include <cute/tensor.hpp>  // 包含 CUTLASS 的张量操作库
#include <cutlass/array.h>  // 包含 CUTLASS 的数组操作
#include <cutlass/cutlass.h> // 包含 CUTLASS 的基础功能
#include <cutlass/numeric_conversion.h> // 包含 CUTLASS 的数值转换功能
#include <cutlass/numeric_types.h>     // 包含 CUTLASS 的数值类型定义/*** 加速和优化深度学习模型中的某些底层操作,尤其是矩阵运算和数据处理步骤,能够显著提高 GPU 上模型训练和推理的性能。*//*** 矩阵运算优化:提供了自定义的规约操作(Allreduce),能够高效地在 Warp 内进行数据的规约(如最大值或求和操作),可用于加速矩阵运算中的某些步骤。实现了一个高效的矩阵乘法(gemm)函数,通过手动展开循环和控制累加器的初始化,能够更好地利用 GPU 的特性和 CUTLASS 的功能,提高矩阵乘法的性能。
张量布局转换:提供了两组函数(convert_layout_acc_rowcol 和 convert_layout_acc_Aregs)用于将张量的布局从一种形式转换为另一种形式,以适应不同架构(如 SM80 和 SM90)的要求,确保数据在 GPU 上能够更高效地存储和访问。
异步内存操作优化:提供了一个异步内存操作等待函数(cp_async_wait),能够精确地控制异步内存操作的等待时间,减少不必要的等待开销,提高 GPU 的利用率。
数据类型转换:提供了一个模板函数(convert_type),能够将张量的数据类型进行转换,支持不同的数值类型(如 float、half 等)之间的转换,增强代码的通用性。
张量拷贝控制:提供了一个张量拷贝函数(copy),允许开发者通过布尔参数精确控制拷贝操作的行为,如是否清除超出边界的元素、是否跳过某些依赖项等,提高拷贝效率和灵活性。
最大值和加法运算:提供了自定义的最大值和加法运算模板(MaxOp 和 SumOp),能够处理各种数据类型(包括 float),并针对特定类型进行了优化,提高运算速度。*/namespace flash { // 定义 flash 命名空间template<typename T>
struct MaxOp { // 定义一个通用的最大值运算模板__device__ __forceinline__ T operator()(T const & x, T const & y) { // 设备端函数,实现二元运算return x > y ? x : y; // 返回两个值中的最大值}
};// 为 float 类型特化 MaxOp,使用硬件优化的 max 函数
template <>
struct MaxOp<float> {__device__ __forceinline__ float operator()(float const &x, float const &y) {return max(x, y); // 更高效地使用硬件 max 指令}
};template<typename T>
struct SumOp { // 定义一个通用的加法运算模板__device__ __forceinline__ T operator()(T const & x, T const & y) { // 设备端函数,实现加法return x + y; // 返回两个值的和}
};template<int THREADS>
struct Allreduce { // 继续进这个结构体是用于 Warp 级规约的
// Allreduce 模板结构体,用于跨线程块的 reducing 操作static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4, // 线程数必须是 32/16/8/4"THREADS must be 32, 16, 8, or 4"); template<typename T, typename Operator> // T 是数据类型,Operator 是运算类型static __device__ __forceinline__ T run(T x, Operator &op) { // 运行规约操作constexpr int OFFSET = THREADS / 2; // 定义线程块的偏移量x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET)); // 使用 shuffle 操作进行规约return Allreduce<OFFSET>::run(x, op); // 递归调用,不断缩小线程块大小}
};template<> // 特化 Allreduce<2>,因为递归终止条件是线程块大小为 2
struct Allreduce<2> {template<typename T, typename Operator>static __device__ __forceinline__ T run(T x, Operator &op) {x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1)); // 偏移量为 1,进行最后一次规约return x;}
};// 定义一个通用的 GEMM 函数,用于矩阵乘法
template<bool zero_init=false, int wg_wait=0, bool arrive=true, bool commit=true, typename Tensor0, typename Tensor1, typename Tensor2, typename TiledMma>
__forceinline__ __device__ void gemm(TiledMma &tiled_mma, Tensor0 const &tCrA, Tensor1 const &tCrB, Tensor2 &tCrC) {constexpr bool Is_RS = !cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value;// 需要移除 tCrA 的 const 属性,因为 warpgroup_fence_operand 不支持 const 张量if constexpr (Is_RS) { cute::warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); }// 对输出张量应用 warpgroup_fence_operandwarpgroup_fence_operand(tCrC);// 根据是否需要同步线程组,执行到来操作if constexpr (arrive) { warpgroup_arrive(); }// 如果需要初始化为零,设置累加器为 Zero,并手动展开 K 维度的循环if constexpr (zero_init) {tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; // 初始化累加器为零CUTLASS_PRAGMA_UNROLLfor (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC); // 执行 GEMMtiled_mma.accumulate_ = GMMA::ScaleOut::One; // 累加操作设置为 One}} else {// 默认情况下,手动展开 K 维度的循环CUTLASS_PRAGMA_UNROLLfor (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);tiled_mma.accumulate_ = GMMA::ScaleOut::One;}}// 根据是否需要提交,执行提交操作if constexpr (commit) {warpgroup_commit_batch();}// 等待指定数量的线程组if constexpr (wg_wait >= 0) { warpgroup_wait<wg_wait>(); }// 对输出张量应用 query 逻辑warpgroup_fence_operand(tCrC);if constexpr (Is_RS) { warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); }
}// 对于不同架构(SM80 和 SM90),将 acc_layout 转换为行/列布局
// 转换规则:将 3D 布局转换为二维布局
template<bool Transposed=false, typename Layout0>
__forceinline__ __device__ auto convert_layout_acc_rowcol(Layout0 acc_layout) {// 判断是否为 SM90 架构if constexpr (decltype(rank<0>(acc_layout))::value == 3) {static_assert(decltype(size<0, 0>(acc_layout))::value == 2);static_assert(decltype(size<0, 1>(acc_layout))::value == 2);static_assert(decltype(rank(acc_layout))::value == 3);auto l = acc_layout;if constexpr (!Transposed) {return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<0, 2>(l), get<2>(l)));} else {return make_layout(make_layout(get<0, 0>(l), get<0, 2>(l), get<2>(l)), make_layout(get<0, 1>(l), get<1>(l)));}} else {static_assert(decltype(size<0>(acc_layout))::value == 4);static_assert(decltype(rank(acc_layout))::value == 3);auto l = logical_divide(acc_layout, Shape<_2>{});if constexpr (!Transposed) {return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l)));} else {return make_layout(make_layout(get<0, 0>(l), get<2>(l)), make_layout(get<0, 1>(l), get<1>(l)));}}
}// 转换光照布局为特定的 Aregs 布局
// 对于不同架构(SM80 和 SM90)和数据类型(FP16/BF16 和 FP8),进行不同的转换
template<typename MMA_Traits, typename Layout0>
__forceinline__ __device__ auto convert_layout_acc_Aregs(Layout0 acc_layout) {using X = Underscore;if constexpr (decltype(rank<0>(acc_layout))::value == 3) {// SM90 架构的转换逻辑static_assert(decltype(size<0, 0>(acc_layout))::value == 2);static_assert(decltype(size<0, 1>(acc_layout))::value == 2);static_assert(decltype(rank(acc_layout))::value == 3);static_assert(decltype(rank(get<0>(acc_layout)))::value == 3);if constexpr (sizeof(typename MMA_Traits::ValTypeA) == 2) {auto l = logical_divide(get<0, 2>(acc_layout), Tile<_2>{});return make_layout(make_layout(get<0, 0>(acc_layout), get<0, 1>(acc_layout), get<0, 0>(l)),get<1>(acc_layout),coalesce(make_layout(get<0, 1>(l), get<2>(acc_layout))));} else {static_assert(sizeof(typename MMA_Traits::ValTypeA) == 1);static_assert(decltype(stride<0, 0>(acc_layout))::value == 1);static_assert(decltype(stride<0, 1>(acc_layout))::value == 2);auto l = logical_divide(get<0, 2>(acc_layout), Tile<Layout<Shape<_2, _2>>> {});return make_layout(make_layout(Layout<_4>{}, get<0, 0, 0>(l), get<0, 0, 1>(l)),get<1>(acc_layout),coalesce(make_layout(get<0, 1>(l), get<2>(acc_layout))));}} else {// SM80 架构的转换逻辑static_assert(decltype(size<0>(acc_layout))::value == 4);static_assert(decltype(rank(acc_layout))::value == 3);constexpr int mma_shape_K = get<2>(typename MMA_Traits::Shape_MNK {});if constexpr (mma_shape_K == 8) {return acc_layout;} else {auto l = logical_divide(acc_layout, Shape<X, X, _2> {});return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l));}}
}// 转换数据类型的模板函数
template <typename To_type, typename Engine, typename Layout>
__forceinline__ __device__ auto convert_type(Tensor<Engine, Layout> const &tensor) {using From_type = typename Engine::value_type; // 获取源数据类型constexpr int numel = decltype(size(tensor))::value; // 获取元素数量// 使用 CUTLASS 的数值转换器进行类型转换cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;auto frag = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel> *>(tensor.data()));return make_tensor(make_rmem_ptr<To_type>(&frag), tensor.layout());
}// 异步内存操作等待函数,用于优化异步操作的等待时间
template <int N>
CUTE_HOST_DEVICE
void cp_async_wait() {
#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED)asm volatile("cp.async.wait_group %0;\n" :: "n"(N)); // 使用内联汇编调用 cp.async.wait_group
#endif
}// 张量拷贝函数,用于优化矩阵拷贝操作
template <bool Is_even_MN=true, bool Is_even_K=true, bool Clear_OOB_MN=false, bool Clear_OOB_K=true,typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1,typename Engine2, typename Layout2, typename Engine3, typename Layout3>
__forceinline__ __device__ void copy(TiledCopy tiled_copy, Tensor<Engine0, Layout0> const &S,Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN,Tensor<Engine3, Layout3> const &predicate_K, const int max_MN=0) {// 静态断言,确保张量的维度一致CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D));CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D));CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D));static_assert(!(Clear_OOB_MN && !Clear_OOB_K)); // 确保参数不会冲突// 循环遍历张量的维度,进行拷贝操作#pragma unrollfor (int m = 0; m < size<1>(S); ++m) {if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {// 满足条件时执行拷贝#pragma unrollfor (int k = 0; k < size<2>(S); ++k) {if (Is_even_K || predicate_K(k)) {cute::copy(tiled_copy, S(_, m, k), D(_, m, k)); // 执行拷贝} else if (Clear_OOB_K) {cute::clear(D(_, m, k)); // 清除超出边界的元素}}} else if (Clear_OOB_MN) {cute::clear(D(_, m, _)); // 清除超出边界的行或列}}
}} // namespace flash
init.py
python">##导入用于机器学习加速(Flash Machine Learning Acceleration,Flash MLA)的接口函数
__version__ = "1.0.0"  # 定义项目的版本号
# 从 flash_mla.flash_mla_interface 模块中导入两个函数
from flash_mla.flash_mla_interface import (get_mla_metadata,  # 获取机器学习加速(MLA)的元数据,可能包含模型设置、超参数等信息flash_mla_with_kvcache  # 使用键值缓存(KV Cache)的机器学习加速函数,用于优化模型推理或训练过程中的数据缓存
)
flash_mla_interface.py
python">#  通过高效的 GPU 编程实现大规模注意力机制的计算,大幅度提升深度学习模型的推理和训练效率
from typing import Optional, Tuple  # 导入类型提示相关的模块import torch  # 导入 PyTorch,用于张量操作和深度学习功能import flash_mla_cuda  # 导入自定义的 CUDA 扩展模块,用于实现 Flash MLA(机器学习加速)功能def get_mla_metadata(  # 获取 Flash MLA 的元数据cache_seqlens: torch.Tensor,  # 缓存的序列长度,shape 为 (batch_size),dtype torch.int32num_heads_per_head_k: int,  # 每个 K 头对应的查询头数量,等于 seq_len_q * num_heads_q // num_heads_knum_heads_k: int,  # K 头的数量
) -> Tuple[torch.Tensor, torch.Tensor]:  # 返回两个张量"""Function to retrieve metadata for Flash Machine Learning Acceleration (MLA).Args:cache_seqlens (torch.Tensor): A 1D tensor of shape (batch_size) containing cache sequence lengths.num_heads_per_head_k (int): Number of query heads per K head (equals seq_len_q * num_heads_q // num_heads_k).num_heads_k (int): Number of K heads.Returns:tile_scheduler_metadata (torch.Tensor): Metadata for tile scheduling, shape (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.num_splits (torch.Tensor): Array of splits for partitioning, shape (batch_size + 1), dtype torch.int32."""return flash_mla_cuda.get_mla_metadata(cache_seqlens, num_heads_per_head_k, num_heads_k)  # 调用 CUDA 函数获取元数据def flash_mla_with_kvcache(  # 使用键值缓存(KV Cache)的 Flash MLA 函数q: torch.Tensor,  # 查询张量,shape (batch_size, seq_len_q, num_heads_q, head_dim)k_cache: torch.Tensor,  # 键缓存,shape (num_blocks, page_block_size, num_heads_k, head_dim)block_table: torch.Tensor,  # 块表格,shape (batch_size, max_num_blocks_per_seq), dtype torch.int32cache_seqlens: torch.Tensor,  # 缓存的序列长度,shape (batch_size), dtype torch.int32head_dim_v: int,  # V 头的维度tile_scheduler_metadata: torch.Tensor,  # 瓦片调度元数据,shape (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32num_splits: torch.Tensor,  # 分割数组,shape (batch_size + 1), dtype torch.int32softmax_scale: Optional[float] = None,  # softmax 的缩放因子,默认为 1 / sqrt(head_dim)causal: bool = False,  # 是否应用因果掩码
) -> Tuple[torch.Tensor, torch.Tensor]:  # 返回两个张量"""Flash Machine Learning Acceleration with KV Cache for efficient attention computation.Args:q (torch.Tensor): Query tensor of shape (batch_size, seq_len_q, num_heads_q, head_dim).k_cache (torch.Tensor): Key cache tensor of shape (num_blocks, page_block_size, num_heads_k, head_dim).block_table (torch.Tensor): Block table tensor of shape (batch_size, max_num_blocks_per_seq), dtype torch.int32.cache_seqlens (torch.Tensor): Cache sequence lengths tensor of shape (batch_size), dtype torch.int32.head_dim_v (int): Head dimension for value.tile_scheduler_metadata (torch.Tensor): Tile scheduler metadata tensor, shape (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.num_splits (torch.Tensor): Array of splits for partitioning, shape (batch_size + 1), dtype torch.int32.softmax_scale (float, optional): Scaling factor for softmax. Defaults to 1 / sqrt(head_dim).causal (bool, optional): Whether to apply causal attention mask. Defaults to False.Returns:out (torch.Tensor): Output tensor of shape (batch_size, seq_len_q, num_heads_q, head_dim_v).softmax_lse (torch.Tensor): LogSumExp values for softmax, shape (batch_size, num_heads_q, seq_len_q), dtype torch.float32."""if softmax_scale is None:  # 如果 softmax_scale 未指定,计算默认值softmax_scale = q.shape[-1] ** (-0.5)  # 默认缩放因子为 1 / sqrt(head_dim)# 调用 CUDA 函数执行前向计算out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla(  q,  # 查询张量k_cache,  # 键缓存None,  # 未使用head_dim_v,  # V 头的维度cache_seqlens,  # 缓存的序列长度block_table,  # 块表格softmax_scale,  # softmax 缩放因子causal,  # 是否应用因果掩码tile_scheduler_metadata,  # 瓦片调度元数据num_splits,  # 分割数组)return out, softmax_lse  # 返回输出张量和 softmax 的 LogSumExp 向量
setup.py
python">#构建和分发一个包含 CUDA 扩展模块的 Python 包,该扩展模块利用了 CUTLASS 库和 Flash Machine Learning Acceleration(Flash MLA)技术来加速深度学习模型中的矩阵运算和注意力机制计算。
import os  # 导入操作系统相关模块,用于获取环境变量和操作目录
from pathlib import Path  # 用于处理文件路径
from datetime import datetime  # 用于获取当前时间
import subprocess  # 用于调用外部命令
from setuptools import setup, find_packages  # 导入 setuptools,用于打包和分发 Python 包
from torch.utils.cpp_extension import (BuildExtension,  # 用于构建 C++/CUDA 扩展的构建扩展类CUDAExtension,  # 用于定义 CUDA 扩展IS_WINDOWS,  # 是否是 Windows 系统
)def append_nvcc_threads(nvcc_extra_args):  # 添加多线程支持"""Append NVIDIA CUDA Compiler (nvcc) threads setting to the compilation arguments.Args:nvcc_extra_args (list): Existing nvcc compilation arguments.Returns:list: Updated nvcc compilation arguments with thread settings."""nvcc_threads = os.getenv("NVCC_THREADS") or "32"  # 获取环境变量中的线程数,默认为 32return nvcc_extra_args + ["--threads", nvcc_threads]  # 添加线程设置到编译参数# 更新 Git 子模块 CUTLASS
subprocess.run(["git", "submodule", "update", "--init", "csrc/cutlass"])  # 初始化和更新 Git 子模块 CUTLASS,用于获得最新的 CUTLASS 库代码# 定义 CUDA 编译架构
cc_flag = []  # 初始化 CUDA 架构标志列表
cc_flag.append("-gencode")  # 添加代码生成标志
cc_flag.append("arch=compute_90a,code=sm_90a")  # 指定目标架构为计算架构 90a 和架构 90athis_dir = os.path.dirname(os.path.abspath(__file__))  # 获取当前脚本所在的目录# 定义 C++ 编译参数
if IS_WINDOWS:cxx_args = ["/O2", "/std:c++17", "/DNDEBUG", "/W0"]  # Windows 系统下的 C++ 编译参数
else:cxx_args = ["-O3", "-std=c++17", "-DNDEBUG", "-Wno-deprecated-declarations"]  # 非 Windows 系统下的 C++ 编译参数# 定义 CUDA 扩展模块
ext_modules = []
ext_modules.append(CUDAExtension(  # 定义 CUDA 扩展模块name="flash_mla_cuda",  # 扩展模块的名称sources=[  # C++ 和 CUDA 源文件路径"csrc/flash_api.cpp","csrc/flash_fwd_mla_bf16_sm90.cu",],extra_compile_args={  # 额外的编译参数"cxx": cxx_args,  # C++ 编译参数"nvcc": append_nvcc_threads(  # CUDA 编译参数["-O3","-std=c++17","-DNDEBUG","-D_USE_MATH_DEFINES","-Wno-deprecated-declarations","-U__CUDA_NO_HALF_OPERATORS__","-U__CUDA_NO_HALF_CONVERSIONS__","-U__CUDA_NO_HALF2_OPERATORS__","-U__CUDA_NO_BFLOAT16_CONVERSIONS__","--expt-relaxed-constexpr","--expt-extended-lambda","--use_fast_math","--ptxas-options=-v,--register-usage-level=10"]+ cc_flag  # 添加 CUDA 架构参数),},include_dirs=[  # 包含目录,指定额外的头文件路径Path(this_dir) / "csrc",Path(this_dir) / "csrc" / "cutlass" / "include",],)
)# 获取 Git 提交哈希或生成时间版本号
try:cmd = ['git', 'rev-parse', '--short', 'HEAD']  # 获取 Git 提交哈希的命令rev = '+' + subprocess.check_output(cmd).decode('ascii').rstrip()  # 执行命令并获取提交哈希
except Exception as _:now = datetime.now()  # 获取当前时间date_time_str = now.strftime("%Y-%m-%d-%H-%M-%S")  # 格式化时间字符串rev = '+' + date_time_str  # 生成基于时间的版本号# 配置打包信息
setup(name="flash_mla",  # 包的名称version="1.0.0" + rev,  # 包的版本号packages=find_packages(include=['flash_mla']),  # 包含的 Python 模块ext_modules=ext_modules,  # 包含的扩展模块cmdclass={"build_ext": BuildExtension},  # 构建类,用于构建 C++/CUDA 扩展
)

3. FlashMLA 的技术亮点

3.1 内存优化

FlashMLA 通过以下技术优化内存使用:

  • 分页 KV 缓存:采用块大小为 64 的分页式显存管理,避免了传统连续内存分配导致的显存碎片化问题。这一设计使得单卡能够并行处理超过 200 个对话线程,服务密度提升 3 倍。

  • BF16 精度支持:FlashMLA 支持 BF16 数据格式,显存占用减少 50%。结合低秩压缩技术,KV 缓存体积可压缩至原体积的 1/4,实现 93.3% 的 KV 缓存量削减。

3.2 性能提升

FlashMLA 在 NVIDIA H800 SXM5 GPU 上表现出卓越的性能:

  • 内存带宽:在内存受限的配置下,FlashMLA 能够实现 3000 GB/s 的带宽。

  • 计算性能:在计算受限的配置下,FlashMLA 能够达到 580 TFLOPS 的计算性能。

  • 性能对比:与传统模型推理框架(如 PyTorch)相比,FlashMLA 的性能提升了 10 倍;与 FlashAttention 相比,性能提升了 3-5 倍

3.3 其他优化

FlashMLA 的设计灵感来源于 FlashAttention 2&3 和英伟达的 CUTLASS 项目。它通过以下方式进一步提升推理效率:

  • 线性复杂度设计:FlashMLA 采用线性复杂度的设计,避免了传统 MHA 的二次复杂度瓶颈。

  • 优化 KV 缓存机制:通过高效的 KV 缓存管理,FlashMLA 实现了更高的推理吞吐量和更低的延迟。

4. 应用场景和优势(涵盖硬件)

FlashMLA 适用场景:
  1. 长序列处理:适合处理数千个标记的文本,如文档分析、长对话等。

  2. 实时应用:如聊天机器人、虚拟助手、实时翻译系统等,显著降低延迟,提升用户体验。

  3. 资源效率优化:减少内存和计算需求,便于在边缘设备或资源受限的环境中部署。

硬件适配与成本优势

FlashMLA 目前仅适配 DeepSeek 的架构模型(如 DeepSeek-R1 和 DeepSeek-V3),并专为 NVIDIA H 系列显卡(如 H800 SXM5)进行了优化。通过优化内存和计算效率,FlashMLA 显著降低了服务器的硬件成本。例如:

  • 推理成本降低:单位推理成本大幅降低,使得 AI 公司和云计算服务商能够在相同的 GPU 资源下处理更多请求。

  • 硬件资源优化:通过减少显存占用和提升计算效率,FlashMLA 使得企业能够在有限的硬件资源下实现更高的推理吞吐量。


http://www.ppmy.cn/server/172328.html

相关文章

手机投屏电脑 Scrcpy

Scrcpy scrcpy github link 手机均需在开发模式下&#xff0c;系统与更新 》开发人员选线 》开启USB调试 、 开启“仅充电”模式下允许ADB调试 1.帮助 .\scrcpy.exe -h2.数据线连接 .\scrcpy.exe -d 3.wifi连接&#xff08;先插上数据线使用命令连接后&#xff0c;再拔…

【Java从入门到起飞】面向对象编程(基础)

文章目录 1. static关键字1.1 概述1.2 定义格式和使用1.2.1 静态变量及其访问1.2.2 实例变量及其访问1.2.3 静态方法及其访问1.2.4 实例方法及其访问 1.3 小结 2. 继承2.1 概述2.1.1 引入2.1.2 继承的含义2.1.3 继承的好处 2.2 继承的格式2.3 子类不能继承的内容2.3.1 引入2.3.…

Linux的软件安装

Linux命令行内的“应用商店” yum命令安装软件。 yum命令&#xff1a; yum&#xff1a;RPM软件管理器&#xff0c;用于自动化安装配置Linux软件&#xff0c;可以自动解决依赖问题。 语法&#xff1a;yum [-y] [install | remove | search] 软件名称 选项&#xff1a;-y。自动确…

java2025热点面试题之springmvc

1. 请解释Spring MVC的工作原理。 答案&#xff1a; Spring MVC是一个基于Java的MVC框架&#xff0c;用于构建Web应用程序。其工作原理如下&#xff1a; 客户端发送请求到DispatcherServlet&#xff0c;它是Spring MVC的前端控制器。DispatcherServlet查询HandlerMapping&…

[数据结构]单值二叉树

思路&#xff1a;校长跟院长比&#xff0c;院长跟主任比&#xff0c;主任跟班长比&#xff0c;班长跟舍长比&#xff0c;只要有一个不同就返回false /** * Definition for a binary tree node. * struct TreeNode { * int val; * struct TreeNode *left; * struct…

西电应用密码学与网络安全实验通关指南

西电应用密码学与网络安全实验通关指南 这是计科网络方向应用密码学与网络安全的随课实验, 占课程分数的20%. 第一次主要是介绍内容, 两周后的第二次实验主要是验收, 实验内容自己线下完成即可. 实验内容如下: 密码学及应用&#xff1a;熟悉云安全实验平台及环境&#xff0c…

Mysql进阶(一)

1. 在ubuntu下安装MySQL数据库 1.1 查看操作系统版本 操作系统版本为Ubuntu22.04. LTS lsb_release -a; 安装成功之后&#xff0c;查看mysql的状态 1.2 查看mysql的状态 1.3 登录mysql mysql -uroot -p; 1.4 退出mysql quit&#xff1b; exit&#xff1b; 2. mysql 程序的…

【向量数据库Weaviate】 和Elasticsearch的区别

Weaviate 和 Elasticsearch 是两种不同类型的数据库&#xff0c;设计目标和应用场景有显著差异。以下是它们的核心区别和适用场景的详细对比&#xff1a; 1. 设计目标与核心能力 维度WeaviateElasticsearch核心能力向量数据库 图数据库&#xff08;语义搜索优先&#xff09;全…