目录
前言
学习 flash-attention-minimal 项目中的 flash attention 实现,记录下个人学习笔记,仅供自己参考😄
refer1:https://github.com/tspeterkim/flash-attention-minimal
1. flash attention
我们先把上篇文章讲解过的 flash attention 算法伪代码实现贴在下面,方便我们在分析代码时做对应的查看:
下面是 Flash Attention 的整体流程:
一、算法基本设置:
- 输入为三个矩阵 Q , K , V ∈ R N × d \mathbf{Q},\mathbf{K},\mathbf{V} \in \mathbb{R}^{N\times d} Q,K,V∈RN×d,初始存储于低速 HBM(高带宽内存)
- 片上内存 SRAM(on-chip SRAM)的大小为 M M M
二、算法详细步骤分析:
步骤 1-2(初始化阶段)
- 步骤 1:设置分块大小
- 将序列长度 N N N 按照列块大小 B c B_c Bc 和行块大小 B r B_r Br 进行划分,使得每次在片上(on-chip SRAM)只处理一小部分 Q \mathbf{Q} Q 和 K \mathbf{K} K
- 设定每个列块的尺寸为 B c = ⌈ M 4 d ⌉ B_c=\lceil \frac{M}{4d} \rceil Bc=⌈4dM⌉,每个行块的尺寸为 B r = min ( ⌈ M 4 d ⌉ , d ) B_r=\min \left(\lceil \frac{M}{4d}\rceil,d\right) Br=min(⌈4dM⌉,d),同时确保其不超过序列长度 N N N
- 步骤 2:在片外内存 HBM 中,初始化输出矩阵 O = ( 0 ) N × d ∈ R N × d \mathbf{O}=(0)_{N\times d}\in \mathbb{R}^{N\times d} O=(0)N×d∈RN×d、归一化因子向量 ℓ = ( 0 ) N ∈ R N \ell = (0)_{N} \in \mathbb{R}^{N} ℓ=(0)N∈RN 以及最大值向量 m = ( − ∞ ) N ∈ R N m = (- \infty)_N \in \mathbb{R}^N m=(−∞)N∈RN
步骤 3-4 (输入输出分块阶段)
- 步骤 3:输入分块(tiling)
- 将输入矩阵 Q \mathbf{Q} Q 划分成 T r = ⌈ N B r ⌉ T_r= \lceil \frac{N}{B_r} \rceil Tr=⌈BrN⌉ 个子块 Q 1 , … , Q T r \mathbf{Q}_1,\ldots,\mathbf{Q}_{T_r} Q1,…,QTr,每个子块大小为 B r × d B_r \times d Br×d
- 同样地,将 K , V \mathbf{K},\mathbf{V} K,V 划分成 T c = ⌈ N B c ⌉ T_c = \lceil \frac{N}{B_c} \rceil Tc=⌈BcN⌉ 个子块 K 1 , … , K T c \mathbf{K}_1,\ldots,\mathbf{K}_{T_c} K1,…,KTc 以及 V 1 , … , V T c \mathbf{V}_1,\ldots,\mathbf{V}_{T_c} V1,…,VTc,每个子块大小为 B c × d B_c \times d Bc×d
- 步骤 4:输出与中间值分块
- 将输出矩阵 O \mathbf{O} O 同样分成 T r T_r Tr 个子块 O 1 , … , O T r \mathbf{O}_1,\ldots,\mathbf{O}_{T_r} O1,…,OTr,每个子块大小为 B r × d B_r\times d Br×d
- 中间归一化因子向量 ℓ \ell ℓ 与中间最大值向量 m m m 分别划分为 T r T_r Tr 个子块 ℓ 1 , … , ℓ T r \ell_1,\ldots,\ell_{T_r} ℓ1,…,ℓTr 以及 m 1 , … , m T r m_1,\ldots,m_{T_r} m1,…,mTr,每块大小为 B r B_r Br
步骤 5-15(分块计算阶段)
- 步骤 5:外层循环(步骤 5-15),遍历所有的键值块(
K
j
,
V
j
\mathbf{K}_j,\mathbf{V}_j
Kj,Vj),
j
=
1
,
…
,
T
c
j=1,\ldots,T_c
j=1,…,Tc
- 步骤 6:加载对应的键值块 K j , V j \mathbf{K}_j,\mathbf{V}_j Kj,Vj 到片上内存 SRAM 上
- 步骤 7:内层循环(步骤 7-14),遍历所有的查询块(
Q
i
\mathbf{Q}_i
Qi),
i
=
1
,
…
,
T
r
i=1,\ldots,T_r
i=1,…,Tr,每个查询块内部的具体计算过程如下:
- 步骤 8:从片外内存 HBM 上加载 Q i , O i , ℓ i , m i \mathbf{Q}_i,\mathbf{O}_i,\ell_i,m_i Qi,Oi,ℓi,mi 到片上内存 SRAM 上
- 步骤 9:在片上内存 SRAM 上计算 attention 分数矩阵: S i j = Q i K j T ∈ R B r × B c \mathbf{S}_{ij}=\mathbf{Q}_i\mathbf{K}_j^T\in \mathbb{R}^{B_r\times B_c} Sij=QiKjT∈RBr×Bc
- 步骤 10:在片上内存上对注意力分数进行分块式 softmax:
- 计算当前子块最大值(用于数值稳定): m ~ i j = r o w m a x ( S i j ) ∈ R B r \tilde{m}_{ij}=\mathrm{rowmax}(\mathbf{S}_{ij}) \in \mathbb{R}^{B_r} m~ij=rowmax(Sij)∈RBr
- 计算指数化矩阵: P ~ i j = e x p ( S i j − m ~ i j ) ∈ R B r × B c \tilde{\mathbf{P}}_{ij} = \mathrm{exp}(\mathbf{S}_{ij}-\tilde{m}_{ij})\in \mathbb{R}^{B_r\times B_c} P~ij=exp(Sij−m~ij)∈RBr×Bc
- 计算局部归一化因子(每行求和): ℓ ~ i j = r o w s u m ( P ~ i j ) ∈ R B r \tilde{\ell}_{ij}=\mathrm{rowsum}(\tilde{\mathbf{P}}_{ij})\in \mathbb{R}^{B_r} ℓ~ij=rowsum(P~ij)∈RBr
- 步骤 11:更新全局归一化因子:
- 新的最大值: m i n e w = max ( m i , m ~ i j ) ∈ R B r m_i^{\mathrm{new}}=\max(m_i,\tilde{m}_{ij})\in \mathbb{R}^{B_r} minew=max(mi,m~ij)∈RBr
- 新的归一化向量: ℓ i n e w = e m i − m i n e w ℓ i + e m ~ i j − m i n e w ℓ ~ i j ∈ R B r \ell_i^{\mathrm{new}}=e^{m_i-m_i^{\mathrm{new}}}\ell_i+e^{\tilde{m}_{ij}-m_i^{new}}\tilde{\ell}_{ij}\in \mathbb{R}^{B_r} ℓinew=emi−minewℓi+em~ij−minewℓ~ij∈RBr
- 步骤 12:计算注意力输出 O i \mathbf{O}_i Oi 并写回 HBM: O i ← d i a g ( ℓ i n e w ) − 1 ( d i a g ( ℓ i ) e m i − m i n e w O i + e m ~ i j − m i n e w P ~ i j V j ) \mathbf{O}_i\leftarrow \mathrm{diag}(\ell_i^{\mathrm{new}})^{-1}(\mathrm{diag}(\ell_i)e^{m_i-m_i^{\mathrm{new}}}\mathbf{O}_i+e^{\tilde{m}_{ij}-m_i^{\mathrm{new}}}\tilde{\mathbf{P}}_{ij}\mathbf{V}_j) Oi←diag(ℓinew)−1(diag(ℓi)emi−minewOi+em~ij−minewP~ijVj)
- 步骤 13:更新全局向量 ℓ i ← ℓ i n e w , m i ← m i n e w \ell_i \leftarrow \ell_i^{\mathrm{new}},m_i\leftarrow m_i^{\mathrm{new}} ℓi←ℓinew,mi←minew 写回 HBM
- 步骤 14:内层循环结束
- 步骤 15:外层循环结束
步骤 16(输出结果阶段)
- 步骤 16:将输出计算的 attention 结果 O \mathbf{O} O 返回
2. flash-attention-minimal
flash-attention-minimal 这个 repo 使用 CUDA 和 PyTorch 对 Flash Attention 进行最小化的重新实现。对于 CUDA 初学者(比如博主)来说,官方的实现可能会让人望而生畏,因此这个 repo 试图做到小而具有教育意义
- 整个 forward 过程在
flash.cu
中仅编写了约 100 行的代码 - 变量名沿用了原始 paper 中的符号
整个项目仅包含以下三个文件:
bench.py
:测试和基准脚本- 用纯 pytorch 的
manual_attn
做 baseline - 用
minimal_attn.forward(...)
调用自定义实现的 flash attention cuda 实现 - 检查两者输出是否一致
- 测试在 gpu 上的执行时间,来确认自定义 flash attention kernel 是否比标准 self-attention 更快
- 用纯 pytorch 的
main.cpp
:桥接文件- 提供 python
⇔
\Leftrightarrow
⇔ c++/cuda 的绑定代码,让 python 可以直接调用
flash.cu
中定义的forward(...)
函数
- 提供 python
⇔
\Leftrightarrow
⇔ c++/cuda 的绑定代码,让 python 可以直接调用
flash.cu
:核心文件- 实现 flash attention 逻辑的 cuda 核函数
forward_kernel
以及一个对外暴露的 c++ 函数torch::Tensor forward(...)
- 实现 flash attention 逻辑的 cuda 核函数
整个项目可以通过如下指令来运行:
git clone https://github.com/tspeterkim/flash-attention-minimal
cd flash-attention-minimal
# flashattn 为虚拟环境名称
conda activate flashattn
python bench.py
执行完成后输出如下图所示:
下面我们来对代码文件逐个分析,重点是 flash.cu
文件
3. bench.py
bench.py
代码如下:
import math
import torch
from torch.nn import functional as F
from torch.utils.cpp_extension import load
# Load the CUDA kernel as a python module
minimal_attn = load(name='minimal_attn', sources=['main.cpp', 'flash.cu'], extra_cuda_cflags=['-O2'])
# Use small model params, otherwise slower than manual attention. See caveats in README.
batch_size = 16
n_head = 12
seq_len = 64
head_embd = 64
q = torch.randn(batch_size, n_head, seq_len, head_embd).cuda()
k = torch.randn(batch_size, n_head, seq_len, head_embd).cuda()
v = torch.randn(batch_size, n_head, seq_len, head_embd).cuda()
print('=== profiling manual attention ===')
# Our minimal flash attention aims to be faster than this by avoiding HBM read/writes of N^2 matrices.
def manual_attn(q, k, v):
att = (q @ k.transpose(-2, -1) * (1.0 / math.sqrt(k.size(-1))))
att = F.softmax(att, dim=-1)
y = att @ v
return y
with torch.autograd.profiler.profile(use_cuda=True) as prof:
manual_result = manual_attn(q, k, v)
print(prof.key_averages().table(sort_by='cuda_time_total', row_limit=10))
print('=== profiling minimal flash attention === ')
with torch.autograd.profiler.profile(use_cuda=True) as prof:
minimal_result = minimal_attn.forward(q, k, v)
print(prof.key_averages().table(sort_by='cuda_time_total', row_limit=10))
print('attn values sanity check:', torch.allclose(minimal_result, manual_result, rtol=0, atol=1e-02))
下面我们来逐代码分析:
import math
import torch
from torch.nn import functional as F
from torch.utils.cpp_extension import load
# Load the CUDA kernel as a python module
minimal_attn = load(name='minimal_attn', sources=['main.cpp', 'flash.cu'], extra_cuda_cflags=['-O2'])
1. 模块导入
load(...)
会使用 pytorch 的 c++ extensions 机制编译并加载给定的 c++/cuda 源文件,最终会生成一个名为minimal_attn
的 python 扩展模块- 这里的
minimal_attn
就是在main.cpp
里定义的PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
,这样就能拿到其中的forward
函数
2. 设置测试参数
batch_size = 16
n_head = 12
seq_len = 64
head_embd = 64
q, k, v
的维度等于(batch_size, n_head, seq_len, head_embd)
3. 初始化测试张量
q = torch.randn(batch_size, n_head, seq_len, head_embd).cuda()
k = torch.randn(batch_size, n_head, seq_len, head_embd).cuda()
v = torch.randn(batch_size, n_head, seq_len, head_embd).cuda()
- 随机生成输入张量,并放到 GPU(
.cuda()
)上
4. self-attention 函数
def manual_attn(q, k, v):
att = (q @ k.transpose(-2, -1) * (1.0 / math.sqrt(k.size(-1))))
att = F.softmax(att, dim=-1)
y = att @ v
return y
- 这是用纯 pytorch 实现的标准注意力: O = s o f t m a x ( Q K T D ) × V O = \mathrm{softmax}\left(\frac{QK^T}{\sqrt{D}}\right)\times V O=softmax(DQKT)×V
manual_attn
属于 baseline,方便与自定义实现的 flash attention 对比
5. profile self-attention(时间测量)
with torch.autograd.profiler.profile(use_cuda=True) as prof:
manual_result = manual_attn(q, k, v)
print(prof.key_averages().table(sort_by='cuda_time_total', row_limit=10))
- 这里用 pytorch 的 profiler 来记录 GPU 时间,得到
manual_attn
的执行开销
6. profile flash attention(时间测量)
- 调用
minimal_attn.forward(q, k, v)
,这里会跳转到 c++/cuda 代码 - 同样用 profiler 测量执行时间,给出统计信息(算子名、总时间等)
7. 结果正确对比
print('attn values sanity check:', torch.allclose(minimal_result, manual_result, rtol=0, atol=1e-02))
- 比较
minimal_result
(GPU kernel 版本)和manual_result
(纯 pytorch 版本)的数值是否在一定误差范围内一致 - 这里设
rtol=0, atol=1e-2
说明只要它们的绝对误差不超过1e-2
就认为结果接近,原因是浮点精度以及 flash attention 里分块累计时会有一些数值误差
总的来说 bench.py
脚本的作用是一个性能测试和正确性验证脚本:
- 1. 先编译并导入我们的 c++/cuda 扩展模块(
minimal_attn
) - 2. 用一个小模型规模生成随机数据
q, k, v
- 3. 分别运行 self-attention 和 flash attention kernel,比较执行时间并检查输出差异
4. main.cpp
main.cpp
代码如下:
#include <torch/extension.h>
torch::Tensor forward(torch::Tensor q, torch::Tensor k, torch::Tensor v);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", torch::wrap_pybind_function(forward), "forward");
}
下面我们来逐代码分析:
1. 包含头文件
#include <torch/extension.h>
- 这里引入了 pytorch c++ extension 提供的头文件,用于与 pytorch/pybind11 做交互
torch::Tensor
、PYBIND11_MODULE
、torch::wrap_pybind_function
等都来自这些接口
2. 声明外部函数
torch::Tensor forward(torch::Tensor q, torch::Tensor k, torch::Tensor v);
- 这里只声明了一个名为
forward
的函数,它的具体实现在flash.cu
中
3. 定义 pytorch extension 模块
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", torch::wrap_pybind_function(forward), "forward");
}
PYBIND11_MODULE
用于把 c++/cuda 代码打包成一个 python 模块,模块名即TORCH_EXTENSION_NAME
- 这里
m.def("forward", ...)
相当于给 python 端注册了一个函数,名字就叫forward
,当我们在 python 中调用minimal_attn.forward(...)
时,底层就会调用到我们在flash.cu
中定义的 c++ 函数forward(...)
torch::wrap_pybind_function
是 pytorch 对pybind11::function
的一层封装,用于让编译器自动帮我们传入/传出的torch::Tensor
参数做 pytorch ⇔ \Leftrightarrow ⇔ c++ 的转换
总的来说 main.cpp
的主要功能是桥接,它把 flash.cu
里的核心算法函数 forward(...)
注册为一个 python 模块的接口函数,供 python 调用
5. flash.cu
flash.cu
代码如下:
#include <torch/types.h>
#include <cuda.h>
#include <cuda_runtime.h>
__global__
void forward_kernel(const float* Q, const float* K, const float* V, const int N, const int d,
const int Tc, const int Tr, const int Bc, const int Br, const float softmax_scale,
float* l, float *m, float* O) {
int tx = threadIdx.x;
int bx = blockIdx.x; int by = blockIdx.y; // batch and head index
// Offset into Q,K,V,O,l,m - different for each batch and head
int qkv_offset = (bx * gridDim.y * N * d) + (by * N * d); // gridDim.y = nh
int lm_offset = (bx * gridDim.y * N) + (by * N); // offset for l and m
// Define SRAM for Q,K,V,S
extern __shared__ float sram[];
int tile_size = Bc * d; // size of Qi, Kj, Vj
float* Qi = sram;
float* Kj = &sram[tile_size];
float* Vj = &sram[tile_size * 2];
float* S = &sram[tile_size * 3];
for (int j = 0; j < Tc; j++) {
// Load Kj, Vj to SRAM
for (int x = 0; x < d; x++) {
Kj[(tx * d) + x] = K[qkv_offset + (tile_size * j) + (tx * d) + x];
Vj[(tx * d) + x] = V[qkv_offset + (tile_size * j) + (tx * d) + x];
}
__syncthreads(); // such that the inner loop can use the correct Kj, Vj
for (int i = 0; i < Tr; i++) {
// Load Qi to SRAM, l and m to registers
for (int x = 0; x < d; x++) {
Qi[(tx * d) + x] = Q[qkv_offset + (tile_size * i) + (tx * d) + x];
}
float row_m_prev = m[lm_offset + (Br * i) + tx];
float row_l_prev = l[lm_offset + (Br * i) + tx];
// S = QK^T, row_m = rowmax(S)
float row_m = -INFINITY;
for (int y = 0; y < Bc; y++) {
float sum = 0;
for (int x = 0; x < d; x++) {
sum += Qi[(tx * d) + x] * Kj[(y * d) + x];
}
sum *= softmax_scale;
S[(Bc * tx) + y] = sum;
if (sum > row_m)
row_m = sum;
}
// P = exp(S - row_m), row_l = rowsum(P)
float row_l = 0;
for (int y = 0; y < Bc; y++) {
S[(Bc * tx) + y] = __expf(S[(Bc * tx) + y] - row_m);
row_l += S[(Bc * tx) + y];
}
// Compute new m and l
float row_m_new = max(row_m_prev, row_m);
float row_l_new = (__expf(row_m_prev - row_m_new) * row_l_prev) + (__expf(row_m - row_m_new) * row_l);
// Write O, l, m to HBM
for (int x = 0; x < d; x++) {
float pv = 0; // Pij * Vj
for (int y = 0; y < Bc; y++) {
pv += S[(Bc * tx) + y] * Vj[(y * d) + x];
}
O[qkv_offset + (tile_size * i) + (tx * d) + x] = (1 / row_l_new) \
* ((row_l_prev * __expf(row_m_prev - row_m_new) * O[qkv_offset + (tile_size * i) + (tx * d) + x]) \
+ (__expf(row_m - row_m_new) * pv));
}
m[lm_offset + (Br * i) + tx] = row_m_new;
l[lm_offset + (Br * i) + tx] = row_l_new;
}
__syncthreads(); // otherwise, thread can use the wrong Kj, Vj in inner loop
}
}
torch::Tensor forward(torch::Tensor Q, torch::Tensor K, torch::Tensor V) {
// TODO: determine Bc, Br dynamically
const int Bc = 32; const int Br = 32;
const int B = Q.size(0); const int nh = Q.size(1);
const int N = Q.size(2); const int d = Q.size(3);
const int Tc = ceil((float) N / Bc); const int Tr = ceil((float) N / Br);
const float softmax_scale = 1.0 / sqrt(d);
// Initialize O, l, m to HBM
auto O = torch::zeros_like(Q);
auto l = torch::zeros({B, nh, N});
auto m = torch::full({B, nh, N}, -INFINITY);
torch::Device device(torch::kCUDA);
l = l.to(device); m = m.to(device);
// Calculate SRAM size needed per block
const int sram_size = (3 * Bc * d * sizeof(float)) + (Bc * Br * sizeof(float));
int max_sram_size;
cudaDeviceGetAttribute(&max_sram_size, cudaDevAttrMaxSharedMemoryPerBlock, 0);
printf("Max shared memory: %d, requested shared memory: %d \\n", max_sram_size, sram_size);
dim3 grid_dim(B, nh); // batch_size x num_heads
dim3 block_dim(Bc); // Bc threads per block
forward_kernel<<<grid_dim, block_dim, sram_size>>>(
Q.data_ptr<float>(), K.data_ptr<float>(), V.data_ptr<float>(),
N, d, Tc, Tr, Bc, Br, softmax_scale,
l.data_ptr<float>(), m.data_ptr<float>(), O.data_ptr<float>()
);
return O;
}
在 flash.cu
中实现了所有 flash attention 的细节逻辑,包括分块 softmax、累加输出、共享内存加载、并行计算等
下面我们分 forward
接口函数和 forward_kernel
逻辑实现核函数两部分来讲解
5.1 forward函数
forward
函数是 forward_kernel
核函数的启动函数,其代码如下:
torch::Tensor forward(torch::Tensor Q, torch::Tensor K, torch::Tensor V) {
// TODO: determine Bc, Br dynamically
const int Bc = 32; const int Br = 32;
const int B = Q.size(0); const int nh = Q.size(1);
const int N = Q.size(2); const int d = Q.size(3);
const int Tc = ceil((float) N / Bc); const int Tr = ceil((float) N / Br);
const float softmax_scale = 1.0 / sqrt(d);
// Initialize O, l, m to HBM
auto O = torch::zeros_like(Q);
auto l = torch::zeros({B, nh, N});
auto m = torch::full({B, nh, N}, -INFINITY);
torch::Device device(torch::kCUDA);
l = l.to(device); m = m.to(device);
// Calculate SRAM size needed per block
const int sram_size = (3 * Bc * d * sizeof(float)) + (Bc * Br * sizeof(float));
int max_sram_size;
cudaDeviceGetAttribute(&max_sram_size, cudaDevAttrMaxSharedMemoryPerBlock, 0);
printf("Max shared memory: %d, requested shared memory: %d \\n", max_sram_size, sram_size);
dim3 grid_dim(B, nh); // batch_size x num_heads
dim3 block_dim(Bc); // Bc threads per block
forward_kernel<<<grid_dim, block_dim, sram_size>>>(
Q.data_ptr<float>(), K.data_ptr<float>(), V.data_ptr<float>(),
N, d, Tc, Tr, Bc, Br, softmax_scale,
l.data_ptr<float>(), m.data_ptr<float>(), O.data_ptr<float>()
);
return O;
}
下面我们来逐代码分析:
1. 函数签名和参数
torch::Tensor forward(torch::Tensor Q, torch::Tensor K, torch::Tensor V)
- 这是在 c++/cuda 上实现的一个包装函数,直接对外暴露 pytorch tensor 接口
Q, K, V
是来自 pytorch 的张量,形状为 ( B , n h , N , d ) (B,nh,N,d) (B,nh,N,d)- B B B:batch size
- n h nh nh:number of heads
- N N N:序列长度
- d d d:每个头的特征维度
此函数将返回一个和 Q
同形状的输出张量
O
\mathbf{O}
O
2. 函数体主要逻辑
2.1 确定分块大小 Bc, Br
// TODO: determine Bc, Br dynamically
const int Bc = 32;
const int Br = 32;
- 做法:硬编码设置块大小 B c = 32 , B r = 32 B_c=32,B_r=32 Bc=32,Br=32
- 对应算法:这相当于在 flash attention 算法 步骤 1(分块大小设定)时,指定了列块大小、行块大小都为 32
- 值得注意的是,通常实际实现中我们会根据硬件的 shared memory 大小、N、d 等情况动态决定最优块大小,这边简化成了固定值
2.2 获取张量形状并计算 Tc, Tr
const int B = Q.size(0); // batch
const int nh = Q.size(1); // num_heads
const int N = Q.size(2); // sequence length
const int d = Q.size(3); // feature dim (per head)
const int Tc = ceil((float) N / Bc);
const int Tr = ceil((float) N / Br);
const float softmax_scale = 1.0 / sqrt(d);
- 从 pytorch 的张量
Q
中获取四个维度:(B, nh, N, d)
Tc = ceil(N / Bc)
,Tr = ceil(N / Br)
:- 对应算法:这正是 步骤 3 所述,将键值矩阵 K , V \mathbf{K},\mathbf{V} K,V 沿序列维度分成 T c T_c Tc 块,将查询矩阵 Q \mathbf{Q} Q 沿序列维度分成 T r T_r Tr 块
softmax_scale = 1.0 / sqrt(d)
:缩放因子 1 d \frac{1}{\sqrt{d}} d1
2.3 初始化输出张量 O \mathbf{O} O、向量 ℓ \ell ℓ、向量 m m m
auto O = torch::zeros_like(Q);
auto l = torch::zeros({B, nh, N});
auto m = torch::full({B, nh, N}, -INFINITY);
torch::Device device(torch::kCUDA);
l = l.to(device);
m = m.to(device);
- 对应算法:这是 步骤 2 所描述的在片外内存 HBM 初始化输出矩阵 O \mathbf{O} O、归一化因子向量 ℓ \ell ℓ 和最大值向量 m m m
auto O = torch::zeros_like(Q);
:创建一个和 Q \mathbf{Q} Q 同形状的全 0 张量,后面会在 kernel 中逐步累加注意力结果auto l = torch::zeros({B, nh, N});
:初始化 ℓ \ell ℓ 为 0auto m = torch::full({B, nh, N}, -INFINITY);
:初始化 m m m 为 − ∞ -\infty −∞l = l.to(device);
,m = m.to(device);
:将l, m
放到 GPU 上
2.4 计算需要的共享内存大小并检查
// Calculate SRAM size needed per block
const int sram_size = (3 * Bc * d * sizeof(float)) + (Bc * Br * sizeof(float));
int max_sram_size;
cudaDeviceGetAttribute(&max_sram_size, cudaDevAttrMaxSharedMemoryPerBlock, 0);
printf("Max shared memory: %d, requested shared memory: %d \n",
max_sram_size, sram_size);
- 对应算法:是我们在 Require(算法基本设置)里所说的要确保片上内存 SRAM 足以容纳当前分块所需的 Q i , K j , V j , S i j \mathbf{Q}_i,\mathbf{K}_j,\mathbf{V}_j,\mathbf{S}_{ij} Qi,Kj,Vj,Sij 等中间结果
- 在
forward_kernel
中我们使用了extern __shared__ float sram[];
来动态分配 shared memory(on-chip SRAM) - 该内存需要容纳:
Kj, Vj, Qi
三个分块:Kj, Vj
分块各有 B c × d B_c\times d Bc×d 的大小,Qi
分块有 B r × d B_r \times d Br×d 的大小,总共3 * Bc * d
floats(本例中 B c = B r = 32 B_c=B_r=32 Bc=Br=32)S
矩阵: B c × B r B_c \times B_r Bc×Br 的大小
- 通过
cudaDevAttrMaxSharedMemoryPerBlock
查询硬件支持的单个 block 最大共享内存max_sram_size
,用来做一个检查或调试信息打印
2.5 设置 kernel 的 gridDim 和 blockDim 并启动 kernel
dim3 grid_dim(B, nh); // batch_size x num_heads
dim3 block_dim(Bc); // Bc threads per block
forward_kernel<<<grid_dim, block_dim, sram_size>>>(
Q.data_ptr<float>(), K.data_ptr<float>(), V.data_ptr<float>(),
N, d, Tc, Tr, Bc, Br, softmax_scale,
l.data_ptr<float>(), m.data_ptr<float>(), O.data_ptr<float>()
);
grid_dim(B, nh)
:2D-layout 网格维度,说明 GPU 上会有B * nh
个线程块,每个 block 负责处理 一个 batch、一个 head 的全部序列 attentionblock_dim(Bc)
:1D-layout 块维度,说明每个 block 里面有Bc
个线程(本例中Bc = 32
),threadIdx.x
范围从 0 到 31- 在
forward_kernel
里,我们将一整个 tile ( B c × d ) (B_c \times d) (Bc×d) 的数据加载到 shared memory,并让Bc
个线程并行计算 - 每个线程处理 1 个 token 即 1 × d 1\times d 1×d 维度的数据
- 在
sram_size
:这是前面计算的 shared memory 大小,用来传给<<<grid_dim, block_dim, sram_size>>>
作为第三个参数,表示动态分配多少 shared memory- 传入 kernel 的参数
Q.data_ptr<float>()
,K.data_ptr<float>()
,V.data_ptr<float>()
:指向 pytorch 张量 Q , K , V \mathbf{Q},\mathbf{K},\mathbf{V} Q,K,V 在 GPU 全局内存中的数据地址N, d, Tc, Tr, Bc, Br, softmax_scale
:各自维度和分块参数,以及 softmax 缩放因子l.data_ptr<float>()
,m.data_ptr<float>()
,O.data_ptr<float>()
:同理,指向 ℓ , m , O \ell,m,\mathbf{O} ℓ,m,O 在显存中的地址
2.6 返回输出结果
return O;
- 返回输出张量 O \mathbf{O} O
总的来说,forward
函数在主机端(Host 端)做了三件事情:配置分块、分配输出与中间缓存(O, l, m
)、启动 CUDA 核函数,在 forward_kernel
核函数内才真正完成了 Flash Attention 的分块计算,下面我们就来看看 forward_kernel
是怎么做的
5.2 forward_kernel函数
forward_kernel
核函数代码如下:
__global__
void forward_kernel(const float* Q, const float* K, const float* V, const int N, const int d,
const int Tc, const int Tr, const int Bc, const int Br, const float softmax_scale,
float* l, float *m, float* O) {
int tx = threadIdx.x;
int bx = blockIdx.x; int by = blockIdx.y; // batch and head index
// Offset into Q,K,V,O,l,m - different for each batch and head
int qkv_offset = (bx * gridDim.y * N * d) + (by * N * d); // gridDim.y = nh
int lm_offset = (bx * gridDim.y * N) + (by * N); // offset for l and m
// Define SRAM for Q,K,V,S
extern __shared__ float sram[];
int tile_size = Bc * d; // size of Qi, Kj, Vj
float* Qi = sram;
float* Kj = &sram[tile_size];
float* Vj = &sram[tile_size * 2];
float* S = &sram[tile_size * 3];
for (int j = 0; j < Tc; j++) {
// Load Kj, Vj to SRAM
for (int x = 0; x < d; x++) {
Kj[(tx * d) + x] = K[qkv_offset + (tile_size * j) + (tx * d) + x];
Vj[(tx * d) + x] = V[qkv_offset + (tile_size * j) + (tx * d) + x];
}
__syncthreads(); // such that the inner loop can use the correct Kj, Vj
for (int i = 0; i < Tr; i++) {
// Load Qi to SRAM, l and m to registers
for (int x = 0; x < d; x++) {
Qi[(tx * d) + x] = Q[qkv_offset + (tile_size * i) + (tx * d) + x];
}
float row_m_prev = m[lm_offset + (Br * i) + tx];
float row_l_prev = l[lm_offset + (Br * i) + tx];
// S = QK^T, row_m = rowmax(S)
float row_m = -INFINITY;
for (int y = 0; y < Bc; y++) {
float sum = 0;
for (int x = 0; x < d; x++) {
sum += Qi[(tx * d) + x] * Kj[(y * d) + x];
}
sum *= softmax_scale;
S[(Bc * tx) + y] = sum;
if (sum > row_m)
row_m = sum;
}
// P = exp(S - row_m), row_l = rowsum(P)
float row_l = 0;
for (int y = 0; y < Bc; y++) {
S[(Bc * tx) + y] = __expf(S[(Bc * tx) + y] - row_m);
row_l += S[(Bc * tx) + y];
}
// Compute new m and l
float row_m_new = max(row_m_prev, row_m);
float row_l_new = (__expf(row_m_prev - row_m_new) * row_l_prev) + (__expf(row_m - row_m_new) * row_l);
// Write O, l, m to HBM
for (int x = 0; x < d; x++) {
float pv = 0; // Pij * Vj
for (int y = 0; y < Bc; y++) {
pv += S[(Bc * tx) + y] * Vj[(y * d) + x];
}
O[qkv_offset + (tile_size * i) + (tx * d) + x] = (1 / row_l_new) \
* ((row_l_prev * __expf(row_m_prev - row_m_new) * O[qkv_offset + (tile_size * i) + (tx * d) + x]) \
+ (__expf(row_m - row_m_new) * pv));
}
m[lm_offset + (Br * i) + tx] = row_m_new;
l[lm_offset + (Br * i) + tx] = row_l_new;
}
__syncthreads(); // otherwise, thread can use the wrong Kj, Vj in inner loop
}
}
下面我们来逐代码分析:
核函数签名如下:
__global__
void forward_kernel(const float* Q, const float* K, const float* V,
const int N, const int d,
const int Tc, const int Tr,
const int Bc, const int Br,
const float softmax_scale,
float* l, float* m, float* O) {
...
}
Q, K, V
:全局内存 HBM 中的输入查询、键、值矩阵指针N, d
:序列长度和特征维度Tc, Tr
:分块个数,对应外层和内存循环需要循环的次数,分别是 ⌈ N B c ⌉ \lceil \frac{N}{B_c} \rceil ⌈BcN⌉ 和 ⌈ N B r ⌉ \lceil \frac{N}{B_r} \rceil ⌈BrN⌉Bc, Br
:列块和行块的分块大小softmax_scale
:缩放因子 1 d \frac{1}{\sqrt{d}} d1,在计算 S i j \mathbf{S}_{ij} Sij 时需要乘以这个缩放因子l, m
:分块存储的归一化因子 ℓ \ell ℓ 和最大值向量 m m m- O O O:最终输出的注意力结果,与 Q \mathbf{Q} Q 同形状
1. 线程、块、批次与多头索引
int tx = threadIdx.x;
int bx = blockIdx.x;
int by = blockIdx.y; // batch 和 head 的二维索引
tx
表示线程(thread)在线程块(block)内的一维索引,范围是0 ~ (Bc - 1)
bx
是线程块在 x 方向上的索引by
是线程块在 y 方向上的索引- 这里
dim3 grid_dim(B, nh)
,所以bx
对应 batch 维度、by
对应多头维度
2. 计算在 Q/K/V/O/l/m 上的偏移量
int qkv_offset = (bx * gridDim.y * N * d) + (by * N * d);
int lm_offset = (bx * gridDim.y * N) + (by * N);
- 由于在外层调用 kernel 时,网格大小设置为
(B, nh)
,其中 B 是 batch 数,nh 是 head 数,因此:gridDim.y = nh
qkv_offset
用于在 (batch, head) 两个维度上,定位到 Q , K , V \mathbf{Q},\mathbf{K},\mathbf{V} Q,K,V 的正确起始地址lm_offset
用于定位 ℓ , m \ell,m ℓ,m 在全局内存中的偏移
3. 声明共享内存并分块
extern __shared__ float sram[];
int tile_size = Bc * d; // 大小 = 每个 tile 中 (Bc 行) * d 列
float* Qi = sram;
float* Kj = &sram[tile_size];
float* Vj = &sram[tile_size * 2];
float* S = &sram[tile_size * 3];
extern __shared__ float sram[];
表示这个核函数会动态分配 shared memory(片上 SRAM),其大小在<<<...>>>
调用时由第三个参数指定- 这里将这片 shared memory 切分成 4 个部分:
Qi
:用于存放 Q i \mathbf{Q}_i Qi 这个 tile,大小为 ( B r × d ) (B_r\times d) (Br×d)kj
:用于存放 K j \mathbf{K}_j Kj 这个 tile,大小为 ( B c × d ) (B_c\times d) (Bc×d)Vj
:用于存放 V j \mathbf{V}_j Vj 这个 tile,大小为 ( B c × d ) (B_c \times d) (Bc×d)S
:用于存放分块注意力得分矩阵 S i j \mathbf{S}_{ij} Sij,大小为 B c × B r B_c \times B_r Bc×Br
- 由于我们需要一次性把这几个子块都放到 shared memory 中做计算,因此必须分配足够大的 shared memory
4. 外层循环,遍历所有键值块
for (int j = 0; j < Tc; j++) {
...
}
- 这对应算法中的 外层循环(步骤 5),也就是遍历所有的 K j \mathbf{K}_j Kj 和 V j \mathbf{V}_j Vj 的列块
4.1 加载 K j , V j \mathbf{K}_j,\mathbf{V}_j Kj,Vj 到共享内存
// Load Kj, Vj to SRAM
for (int x = 0; x < d; x++) {
Kj[(tx * d) + x] = K[qkv_offset + (tile_size * j) + (tx * d) + x];
Vj[(tx * d) + x] = V[qkv_offset + (tile_size * j) + (tx * d) + x];
}
__syncthreads();
- 对应算法步骤 6:从片外 HBM 中加载当前列块 K j , V j \mathbf{K}_j,\mathbf{V}_j Kj,Vj 到片上共享内存
tile_size * j = (Bc * d) * j
表示第 j 个列块在全局 K/V 内存中的起始位置tx
表示线程索引,这里假设每个线程拷贝一行的(即 Bc 行的其中一行)d 个元素,所以 Kj/Vj 的索引是(tx * d) + x
__syncthreads();
保证 block 内所有线程都完成数据加载后再进行后续操作,以确保共享内存里Kj, Vj
是完整的
关于索引的计算博主绘制了一个草图来说明:
Note:我们以键矩阵为例来说明(值矩阵类似)
K
K
K 在内存中是连续存储的,数据总长度是 B * nh * N * d
,现在我们的目的是找到全局
K
K
K 内存中某个 tile 块中的某个元素(图示红色小块)的全局位置索引
我们可以借用 NVIDIA 的 gridDim 和 blockDim 类似的布局思想,假设 HBM 中存储的
K
K
K 是以 2D-layout 布局,总共有 B * nh
个大块,每个大块的维度是 N * d
每个小块被切分成了 Tc
个 tile(本例中为 2),则图中红色小块的全局位置索引可以通过 qkv_offset + (tile_size * j) + (tx * d) + x
获取,其中:
qkv_offset
定位到B * nh
个大块中具体哪个大块tile_size * j
定位到Tc
个 tile 中的哪个 tiletx * d
定位到哪个 thread 处理x
定位到属于d
维度的哪个元素
5. 内层循环:遍历所有查询块
for (int i = 0; i < Tr; i++) {
...
}
- 对应算法步骤 7:遍历行块
Q
i
\mathbf{Q}_i
Qi,每个
i
表示一个查询子块(行块),大小为 B r × d B_r \times d Br×d
5.1 加载 Q i \mathbf{Q}_i Qi、 O i \mathbf{O}_i Oi 及 ℓ i , m i \ell_i,m_i ℓi,mi
// Load Qi to SRAM, l and m to registers
for (int x = 0; x < d; x++) {
Qi[(tx * d) + x] = Q[qkv_offset + (tile_size * i) + (tx * d) + x];
}
float row_m_prev = m[lm_offset + (Br * i) + tx];
float row_l_prev = l[lm_offset + (Br * i) + tx];
- 对应算法步骤 8:把当前的
Q
i
\mathbf{Q}_i
Qi 分块加载进共享内存
Qi
,同时从全局内存读入对应该行块的m_i
、l_i
到寄存器 - 因为要对
ℓ
i
\ell_i
ℓi(softmax 归一化因子)和
m
i
m_i
mi(最大值)做更新,需要先读旧的值。这里用
row_l_prev
、row_m_prev
表示 ℓ i \ell_i ℓi 与 m i m_i mi 的旧值
Note:这里并没有显式加载
O
i
\mathbf{O}_i
Oi 到共享内存,而是会在更新时直接访问全局内存的 O[...]
进行读写
5.2 计算分块注意力分数 S i j = Q i K j T \mathbf{S}_{ij} = \mathbf{Q}_i\mathbf{K}^T_j Sij=QiKjT
// S = QK^T, row_m = rowmax(S)
float row_m = -INFINITY;
for (int y = 0; y < Bc; y++) {
float sum = 0;
for (int x = 0; x < d; x++) {
sum += Qi[(tx * d) + x] * Kj[(y * d) + x];
}
sum *= softmax_scale;
S[(Bc * tx) + y] = sum;
if (sum > row_m)
row_m = sum;
}
- 对应算法步骤 9:计算 S i j = Q i K j T \mathbf{S}_{ij} = \mathbf{Q}_i\mathbf{K}^T_j Sij=QiKjT
- 其中:
Qi[(tx * d) + x]
取到该线程处理的某行 Q i \mathbf{Q}_i QiKj[(y * d) + x]
取到 K j \mathbf{K}_j Kj 的第y
行- 然后逐元素相乘累加得到
sum
- 乘以缩放因子
softmax_scale
- 将结果写入共享内存
S[(Bc * tx) + y]
- 这里也顺便计算行最大值
row_m
,相当于 m ~ i j = r o w m a x ( S i j ) \tilde{m}_{ij} = \mathrm{rowmax}(\mathbf{S}_{ij}) m~ij=rowmax(Sij)
5.3 做局部 softmax: P ~ i j = e x p ( S i j − m ~ i j ) \tilde{\mathbf{P}}_{ij} = \mathrm{exp}(\mathbf{S}_{ij}-\tilde{m}_{ij}) P~ij=exp(Sij−m~ij),并求和
// P = exp(S - row_m), row_l = rowsum(P)
float row_l = 0;
for (int y = 0; y < Bc; y++) {
S[(Bc * tx) + y] = __expf(S[(Bc * tx) + y] - row_m);
row_l += S[(Bc * tx) + y];
}
- 对应算法步骤 10:对注意力分数做分块式 softmax:
- 减去行最大值
row_m
(即 m ~ i j \tilde{m}_{ij} m~ij)以做数值稳定 - 指数化 exp ( ⋅ ) \exp(\cdot) exp(⋅),得到 P ~ i j \tilde{\mathbf{P}}_{ij} P~ij
- 行向量求和,得到局部归一化因子
ℓ
~
i
j
\tilde{\ell}_{ij}
ℓ~ij,对应代码中的
row_l
变量
- 减去行最大值
5.4 计算新的全局最大值与归一化因子
// Compute new m and l
float row_m_new = max(row_m_prev, row_m);
float row_l_new = (__expf(row_m_prev - row_m_new) * row_l_prev)
+ (__expf(row_m - row_m_new) * row_l);
- 对应算法步骤 11:计算新全局向量
m
i
n
e
w
,
ℓ
i
n
e
w
m_i^{\mathrm{new}},\ell_i^{\mathrm{new}}
minew,ℓinew,参考公式:
- m i n e w = max ( m i , m ~ i j ) m_i^{\mathrm{new}}=\max(m_i,\tilde{m}_{ij}) minew=max(mi,m~ij)
- ℓ i n e w = e m i − m i n e w ℓ i + e m ~ i j − m i n e w ℓ ~ i j \ell_i^{\mathrm{new}}=e^{m_i-m_i^{\mathrm{new}}}\ell_i+e^{\tilde{m}_{ij}-m_i^{new}}\tilde{\ell}_{ij} ℓinew=emi−minewℓi+em~ij−minewℓ~ij
- 这里
row_m_prev
相当于上一次累积的 m i m_i mi,row_m
是本块算出的 m ~ i j \tilde{m}_{ij} m~ij,分别对应公式中的 m i m_i mi 和 m ~ i j \tilde{m}_{ij} m~ij row_l_prev
是旧的 ℓ i \ell_i ℓi,row_l
是本块的 ℓ ~ i j \tilde{\ell}_{ij} ℓ~ijrow_m_new
、row_l_new
就是更新后的 m i n e w m_i^{\mathrm{new}} minew 和 ℓ i n e w \ell_i^{\mathrm{new}} ℓinew
5.5 更新输出 O i \mathbf{O}_i Oi
for (int x = 0; x < d; x++) {
float pv = 0; // Pij * Vj
for (int y = 0; y < Bc; y++) {
pv += S[(Bc * tx) + y] * Vj[(y * d) + x];
}
O[qkv_offset + (tile_size * i) + (tx * d) + x]
= (1 / row_l_new) * (
(row_l_prev * __expf(row_m_prev - row_m_new)
* O[qkv_offset + (tile_size * i) + (tx * d) + x])
+ (__expf(row_m - row_m_new) * pv)
);
}
- 对应算法步骤 12:分块累加更新 O i \mathbf{O}_i Oi,根据公式:
- O i ← d i a g ( ℓ i n e w ) − 1 ( d i a g ( ℓ i ) e m i − m i n e w O i + e m ~ i j − m i n e w P ~ i j V j ) \mathbf{O}_i\leftarrow \mathrm{diag}(\ell_i^{\mathrm{new}})^{-1}(\mathrm{diag}(\ell_i)e^{m_i-m_i^{\mathrm{new}}}\mathbf{O}_i+e^{\tilde{m}_{ij}-m_i^{\mathrm{new}}}\tilde{\mathbf{P}}_{ij}\mathbf{V}_j) Oi←diag(ℓinew)−1(diag(ℓi)emi−minewOi+em~ij−minewP~ijVj)
- 对应到代码里:
pv
对应 P ~ i j V j \tilde{\mathbf{P}}_{ij}\mathbf{V}_j P~ijVj 逐元素相乘再累加(row_l_prev * __expf(row_m_prev - row_m_new) * O[...]
对应旧的 O i \mathbf{O}_i Oi 部分乘上相应的因子 d i a g ( ℓ i ) e m i − m i n e w \mathrm{diag}(\ell_i)e^{m_i-m_i^{\mathrm{new}}} diag(ℓi)emi−minew__expf(row_m - row_m_new) * pv
对应新的部分(1 / row_l_new)
则是对应乘上的 d i a g ( ℓ i n e w ) − 1 \mathrm{diag}(\ell_i^{\mathrm{new}})^{-1} diag(ℓinew)−1 这个归一化因子
5.6 写回 ℓ i , m i \ell_i,m_i ℓi,mi
m[lm_offset + (Br * i) + tx] = row_m_new;
l[lm_offset + (Br * i) + tx] = row_l_new;
- 对应算法步骤 13:将新的 m i n e w , ℓ i n e w m_i^{\mathrm{new}}, \ell_i^{\mathrm{new}} minew,ℓinew 写回全局内存 HBM
- 完成对第
i
块(行块)的处理,内层循环推进到下一块
__syncthreads();
- 对应算法步骤 14:内层循环结束(
j++
继续下一个 tile),用__syncthreads()
确保当前的Kj, Vj
处理完成
6. 外层循环结束
- 对应算法步骤 15:当所有
j
都完成后,外层循环也结束,即完成了遍历所有 K j , V j \mathbf{K}_j,\mathbf{V}_j Kj,Vj
总的来说,flash.cu
这份 CUDA 代码是一个比较直接的 flash attention 分块公式实现:每次加载一块
K
j
,
V
j
\mathbf{K}_j,\mathbf{V}_j
Kj,Vj,然后对所有
Q
i
\mathbf{Q}_i
Qi 进行 partial softmax、partial output 累加,更新
O
i
,
ℓ
i
,
m
i
\mathbf{O}_i,\ell_i,m_i
Oi,ℓi,mi,当所有块都处理完后,就得到完整的注意力输出。
结语
在 flash-attention-minimal 中,最小化的 flash attention 实现过程还是非常清晰的,其大致思路是:
1. 先在 cpu/python 端根据 N , d N,d N,d 等参数设置好分块大小 B c , B r B_c,B_r Bc,Br,初始化 O , ℓ , m \mathbf{O},\ell, m O,ℓ,m
2. 在 gpu 端,使用外层循环遍历所有列块 K j , V j \mathbf{K}_j,\mathbf{V}_j Kj,Vj,并将它们加载到共享内层
3. 内层循环遍历所有行块 Q i \mathbf{Q}_i Qi,并同样加载到共享内层后,计算分块注意力分数 S i j \mathbf{S}_{ij} Sij,执行分块 softmax,更新全局归一化因子 ℓ \ell ℓ 与输出 O \mathbf{O} O
4. 循环完成后, O \mathbf{O} O 就是完整的注意力输出
代码里每一部分(加载、乘法、指数化、最大值即归一化更新、输出累加写回)都能和 flash attention 论文中的算法伪代码实现对应上,作为 flash attention 代码实现的入门还是非常不错的
大家感兴趣的可以看看,整个实现非常的简洁🤗