Flash Attention代码实现

前言

学习 flash-attention-minimal 项目中的 flash attention 实现,记录下个人学习笔记,仅供自己参考😄

refer1:https://github.com/tspeterkim/flash-attention-minimal

refer2:https://chatgpt.com/

1. flash attention

我们先把上篇文章讲解过的 flash attention 算法伪代码实现贴在下面,方便我们在分析代码时做对应的查看:

在这里插入图片描述

下面是 Flash Attention 的整体流程:

一、算法基本设置:

  • 输入为三个矩阵 Q , K , V ∈ R N × d \mathbf{Q},\mathbf{K},\mathbf{V} \in \mathbb{R}^{N\times d} Q,K,VRN×d,初始存储于低速 HBM(高带宽内存)
  • 片上内存 SRAM(on-chip SRAM)的大小为 M M M

二、算法详细步骤分析:

步骤 1-2(初始化阶段)

  • 步骤 1:设置分块大小
    • 将序列长度 N N N 按照列块大小 B c B_c Bc 和行块大小 B r B_r Br 进行划分,使得每次在片上(on-chip SRAM)只处理一小部分 Q \mathbf{Q} Q K \mathbf{K} K
    • 设定每个列块的尺寸为 B c = ⌈ M 4 d ⌉ B_c=\lceil \frac{M}{4d} \rceil Bc=4dM,每个行块的尺寸为 B r = min ⁡ ( ⌈ M 4 d ⌉ , d ) B_r=\min \left(\lceil \frac{M}{4d}\rceil,d\right) Br=min(4dM,d),同时确保其不超过序列长度 N N N
  • 步骤 2:在片外内存 HBM 中,初始化输出矩阵 O = ( 0 ) N × d ∈ R N × d \mathbf{O}=(0)_{N\times d}\in \mathbb{R}^{N\times d} O=(0)N×dRN×d、归一化因子向量 ℓ = ( 0 ) N ∈ R N \ell = (0)_{N} \in \mathbb{R}^{N} =(0)NRN 以及最大值向量 m = ( − ∞ ) N ∈ R N m = (- \infty)_N \in \mathbb{R}^N m=()NRN

步骤 3-4 (输入输出分块阶段)

  • 步骤 3:输入分块(tiling)
    • 将输入矩阵 Q \mathbf{Q} Q 划分成 T r = ⌈ N B r ⌉ T_r= \lceil \frac{N}{B_r} \rceil Tr=BrN 个子块 Q 1 , … , Q T r \mathbf{Q}_1,\ldots,\mathbf{Q}_{T_r} Q1,,QTr,每个子块大小为 B r × d B_r \times d Br×d
    • 同样地,将 K , V \mathbf{K},\mathbf{V} K,V 划分成 T c = ⌈ N B c ⌉ T_c = \lceil \frac{N}{B_c} \rceil Tc=BcN 个子块 K 1 , … , K T c \mathbf{K}_1,\ldots,\mathbf{K}_{T_c} K1,,KTc 以及 V 1 , … , V T c \mathbf{V}_1,\ldots,\mathbf{V}_{T_c} V1,,VTc,每个子块大小为 B c × d B_c \times d Bc×d
  • 步骤 4:输出与中间值分块
    • 将输出矩阵 O \mathbf{O} O 同样分成 T r T_r Tr 个子块 O 1 , … , O T r \mathbf{O}_1,\ldots,\mathbf{O}_{T_r} O1,,OTr,每个子块大小为 B r × d B_r\times d Br×d
    • 中间归一化因子向量 ℓ \ell 与中间最大值向量 m m m 分别划分为 T r T_r Tr 个子块 ℓ 1 , … , ℓ T r \ell_1,\ldots,\ell_{T_r} 1,,Tr 以及 m 1 , … , m T r m_1,\ldots,m_{T_r} m1,,mTr,每块大小为 B r B_r Br

步骤 5-15(分块计算阶段)

  • 步骤 5:外层循环(步骤 5-15),遍历所有的键值块( K j , V j \mathbf{K}_j,\mathbf{V}_j Kj,Vj), j = 1 , … , T c j=1,\ldots,T_c j=1,,Tc
    • 步骤 6:加载对应的键值块 K j , V j \mathbf{K}_j,\mathbf{V}_j Kj,Vj 到片上内存 SRAM 上
    • 步骤 7:内层循环(步骤 7-14),遍历所有的查询块( Q i \mathbf{Q}_i Qi), i = 1 , … , T r i=1,\ldots,T_r i=1,,Tr,每个查询块内部的具体计算过程如下:
      • 步骤 8:从片外内存 HBM 上加载 Q i , O i , ℓ i , m i \mathbf{Q}_i,\mathbf{O}_i,\ell_i,m_i Qi,Oi,i,mi 到片上内存 SRAM 上
      • 步骤 9:在片上内存 SRAM 上计算 attention 分数矩阵: S i j = Q i K j T ∈ R B r × B c \mathbf{S}_{ij}=\mathbf{Q}_i\mathbf{K}_j^T\in \mathbb{R}^{B_r\times B_c} Sij=QiKjTRBr×Bc
      • 步骤 10:在片上内存上对注意力分数进行分块式 softmax:
        • 计算当前子块最大值(用于数值稳定): m ~ i j = r o w m a x ( S i j ) ∈ R B r \tilde{m}_{ij}=\mathrm{rowmax}(\mathbf{S}_{ij}) \in \mathbb{R}^{B_r} m~ij=rowmax(Sij)RBr
        • 计算指数化矩阵: P ~ i j = e x p ( S i j − m ~ i j ) ∈ R B r × B c \tilde{\mathbf{P}}_{ij} = \mathrm{exp}(\mathbf{S}_{ij}-\tilde{m}_{ij})\in \mathbb{R}^{B_r\times B_c} P~ij=exp(Sijm~ij)RBr×Bc
        • 计算局部归一化因子(每行求和): ℓ ~ i j = r o w s u m ( P ~ i j ) ∈ R B r \tilde{\ell}_{ij}=\mathrm{rowsum}(\tilde{\mathbf{P}}_{ij})\in \mathbb{R}^{B_r} ~ij=rowsum(P~ij)RBr
      • 步骤 11:更新全局归一化因子:
        • 新的最大值: m i n e w = max ⁡ ( m i , m ~ i j ) ∈ R B r m_i^{\mathrm{new}}=\max(m_i,\tilde{m}_{ij})\in \mathbb{R}^{B_r} minew=max(mi,m~ij)RBr
        • 新的归一化向量: ℓ i n e w = e m i − m i n e w ℓ i + e m ~ i j − m i n e w ℓ ~ i j ∈ R B r \ell_i^{\mathrm{new}}=e^{m_i-m_i^{\mathrm{new}}}\ell_i+e^{\tilde{m}_{ij}-m_i^{new}}\tilde{\ell}_{ij}\in \mathbb{R}^{B_r} inew=emiminewi+em~ijminew~ijRBr
      • 步骤 12:计算注意力输出 O i \mathbf{O}_i Oi 并写回 HBM: O i ← d i a g ( ℓ i n e w ) − 1 ( d i a g ( ℓ i ) e m i − m i n e w O i + e m ~ i j − m i n e w P ~ i j V j ) \mathbf{O}_i\leftarrow \mathrm{diag}(\ell_i^{\mathrm{new}})^{-1}(\mathrm{diag}(\ell_i)e^{m_i-m_i^{\mathrm{new}}}\mathbf{O}_i+e^{\tilde{m}_{ij}-m_i^{\mathrm{new}}}\tilde{\mathbf{P}}_{ij}\mathbf{V}_j) Oidiag(inew)1(diag(i)emiminewOi+em~ijminewP~ijVj)
      • 步骤 13:更新全局向量 ℓ i ← ℓ i n e w , m i ← m i n e w \ell_i \leftarrow \ell_i^{\mathrm{new}},m_i\leftarrow m_i^{\mathrm{new}} iinew,miminew 写回 HBM
    • 步骤 14:内层循环结束
  • 步骤 15:外层循环结束

步骤 16(输出结果阶段)

  • 步骤 16:将输出计算的 attention 结果 O \mathbf{O} O 返回

2. flash-attention-minimal

flash-attention-minimal 这个 repo 使用 CUDA 和 PyTorch 对 Flash Attention 进行最小化的重新实现。对于 CUDA 初学者(比如博主)来说,官方的实现可能会让人望而生畏,因此这个 repo 试图做到小而具有教育意义

  • 整个 forward 过程在 flash.cu 中仅编写了约 100 行的代码
  • 变量名沿用了原始 paper 中的符号

整个项目仅包含以下三个文件:

  • bench.py:测试和基准脚本
    • 用纯 pytorch 的 manual_attn 做 baseline
    • minimal_attn.forward(...) 调用自定义实现的 flash attention cuda 实现
    • 检查两者输出是否一致
    • 测试在 gpu 上的执行时间,来确认自定义 flash attention kernel 是否比标准 self-attention 更快
  • main.cpp:桥接文件
    • 提供 python ⇔ \Leftrightarrow c++/cuda 的绑定代码,让 python 可以直接调用 flash.cu 中定义的 forward(...) 函数
  • flash.cu:核心文件
    • 实现 flash attention 逻辑的 cuda 核函数 forward_kernel 以及一个对外暴露的 c++ 函数 torch::Tensor forward(...)

整个项目可以通过如下指令来运行:

git clone https://github.com/tspeterkim/flash-attention-minimal
cd flash-attention-minimal
# flashattn 为虚拟环境名称
conda activate flashattn
python bench.py

执行完成后输出如下图所示:

在这里插入图片描述

下面我们来对代码文件逐个分析,重点是 flash.cu 文件

3. bench.py

bench.py 代码如下:

import math

import torch
from torch.nn import functional as F
from torch.utils.cpp_extension import load

# Load the CUDA kernel as a python module
minimal_attn = load(name='minimal_attn', sources=['main.cpp', 'flash.cu'], extra_cuda_cflags=['-O2'])

# Use small model params, otherwise slower than manual attention. See caveats in README.
batch_size = 16
n_head = 12
seq_len = 64
head_embd = 64

q = torch.randn(batch_size, n_head, seq_len, head_embd).cuda()
k = torch.randn(batch_size, n_head, seq_len, head_embd).cuda()
v = torch.randn(batch_size, n_head, seq_len, head_embd).cuda()

print('=== profiling manual attention ===')

# Our minimal flash attention aims to be faster than this by avoiding HBM read/writes of N^2 matrices.
def manual_attn(q, k, v):
    att = (q @ k.transpose(-2, -1) * (1.0 / math.sqrt(k.size(-1))))
    att = F.softmax(att, dim=-1)
    y = att @ v
    return y

with torch.autograd.profiler.profile(use_cuda=True) as prof:
    manual_result = manual_attn(q, k, v)
print(prof.key_averages().table(sort_by='cuda_time_total', row_limit=10))

print('=== profiling minimal flash attention === ')

with torch.autograd.profiler.profile(use_cuda=True) as prof:
    minimal_result = minimal_attn.forward(q, k, v)
print(prof.key_averages().table(sort_by='cuda_time_total', row_limit=10))

print('attn values sanity check:', torch.allclose(minimal_result, manual_result, rtol=0, atol=1e-02))

下面我们来逐代码分析:

import math

import torch
from torch.nn import functional as F
from torch.utils.cpp_extension import load

# Load the CUDA kernel as a python module
minimal_attn = load(name='minimal_attn', sources=['main.cpp', 'flash.cu'], extra_cuda_cflags=['-O2'])

1. 模块导入

  • load(...) 会使用 pytorch 的 c++ extensions 机制编译并加载给定的 c++/cuda 源文件,最终会生成一个名为 minimal_attn 的 python 扩展模块
  • 这里的 minimal_attn 就是在 main.cpp 里定义的 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m),这样就能拿到其中的 forward 函数

2. 设置测试参数

batch_size = 16
n_head = 12
seq_len = 64
head_embd = 64
  • q, k, v 的维度等于 (batch_size, n_head, seq_len, head_embd)

3. 初始化测试张量

q = torch.randn(batch_size, n_head, seq_len, head_embd).cuda()
k = torch.randn(batch_size, n_head, seq_len, head_embd).cuda()
v = torch.randn(batch_size, n_head, seq_len, head_embd).cuda()
  • 随机生成输入张量,并放到 GPU(.cuda())上

4. self-attention 函数

def manual_attn(q, k, v):
    att = (q @ k.transpose(-2, -1) * (1.0 / math.sqrt(k.size(-1))))
    att = F.softmax(att, dim=-1)
    y = att @ v
    return y
  • 这是用纯 pytorch 实现的标准注意力: O = s o f t m a x ( Q K T D ) × V O = \mathrm{softmax}\left(\frac{QK^T}{\sqrt{D}}\right)\times V O=softmax(D QKT)×V
  • manual_attn 属于 baseline,方便与自定义实现的 flash attention 对比

5. profile self-attention(时间测量)

with torch.autograd.profiler.profile(use_cuda=True) as prof:
    manual_result = manual_attn(q, k, v)
print(prof.key_averages().table(sort_by='cuda_time_total', row_limit=10))
  • 这里用 pytorch 的 profiler 来记录 GPU 时间,得到 manual_attn 的执行开销

6. profile flash attention(时间测量)

  • 调用 minimal_attn.forward(q, k, v),这里会跳转到 c++/cuda 代码
  • 同样用 profiler 测量执行时间,给出统计信息(算子名、总时间等)

7. 结果正确对比

print('attn values sanity check:', torch.allclose(minimal_result, manual_result, rtol=0, atol=1e-02))
  • 比较 minimal_result(GPU kernel 版本)和 manual_result(纯 pytorch 版本)的数值是否在一定误差范围内一致
  • 这里设 rtol=0, atol=1e-2 说明只要它们的绝对误差不超过 1e-2 就认为结果接近,原因是浮点精度以及 flash attention 里分块累计时会有一些数值误差

总的来说 bench.py 脚本的作用是一个性能测试和正确性验证脚本:

  • 1. 先编译并导入我们的 c++/cuda 扩展模块(minimal_attn
  • 2. 用一个小模型规模生成随机数据 q, k, v
  • 3. 分别运行 self-attention 和 flash attention kernel,比较执行时间并检查输出差异

4. main.cpp

main.cpp 代码如下:

#include <torch/extension.h>

torch::Tensor forward(torch::Tensor q, torch::Tensor k, torch::Tensor v);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", torch::wrap_pybind_function(forward), "forward");
}

下面我们来逐代码分析:

1. 包含头文件

#include <torch/extension.h>
  • 这里引入了 pytorch c++ extension 提供的头文件,用于与 pytorch/pybind11 做交互
  • torch::TensorPYBIND11_MODULEtorch::wrap_pybind_function 等都来自这些接口

2. 声明外部函数

torch::Tensor forward(torch::Tensor q, torch::Tensor k, torch::Tensor v);
  • 这里只声明了一个名为 forward 的函数,它的具体实现在 flash.cu

3. 定义 pytorch extension 模块

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", torch::wrap_pybind_function(forward), "forward");
}
  • PYBIND11_MODULE 用于把 c++/cuda 代码打包成一个 python 模块,模块名即 TORCH_EXTENSION_NAME
  • 这里 m.def("forward", ...) 相当于给 python 端注册了一个函数,名字就叫 forward,当我们在 python 中调用 minimal_attn.forward(...) 时,底层就会调用到我们在 flash.cu 中定义的 c++ 函数 forward(...)
  • torch::wrap_pybind_function 是 pytorch 对 pybind11::function 的一层封装,用于让编译器自动帮我们传入/传出的 torch::Tensor 参数做 pytorch ⇔ \Leftrightarrow c++ 的转换

总的来说 main.cpp 的主要功能是桥接,它把 flash.cu 里的核心算法函数 forward(...) 注册为一个 python 模块的接口函数,供 python 调用

5. flash.cu

flash.cu 代码如下:

#include <torch/types.h>
#include <cuda.h>
#include <cuda_runtime.h>

__global__
void forward_kernel(const float* Q, const float* K, const float* V, const int N, const int d,
                    const int Tc, const int Tr, const int Bc, const int Br, const float softmax_scale,
                    float* l, float *m, float* O) {
    int tx = threadIdx.x;
    int bx = blockIdx.x; int by = blockIdx.y;  // batch and head index

    // Offset into Q,K,V,O,l,m - different for each batch and head
    int qkv_offset = (bx * gridDim.y * N * d) + (by * N * d);  // gridDim.y = nh
    int lm_offset = (bx * gridDim.y * N) + (by * N);  // offset for l and m

    // Define SRAM for Q,K,V,S
    extern __shared__ float sram[];
    int tile_size = Bc * d;  // size of Qi, Kj, Vj
    float* Qi = sram;
    float* Kj = &sram[tile_size];
    float* Vj = &sram[tile_size * 2];
    float* S = &sram[tile_size * 3];

    for (int j = 0; j < Tc; j++) {

        // Load Kj, Vj to SRAM
        for (int x = 0; x < d; x++) {
            Kj[(tx * d) + x] = K[qkv_offset + (tile_size * j) + (tx * d) + x];
            Vj[(tx * d) + x] = V[qkv_offset + (tile_size * j) + (tx * d) + x];
        }
        __syncthreads();  // such that the inner loop can use the correct Kj, Vj

        for (int i = 0; i < Tr; i++)  {

            // Load Qi to SRAM, l and m to registers
            for (int x = 0; x < d; x++) {
                Qi[(tx * d) + x] = Q[qkv_offset + (tile_size * i) + (tx * d) + x];
            }
            float row_m_prev = m[lm_offset + (Br * i) + tx];
            float row_l_prev = l[lm_offset + (Br * i) + tx];

            // S = QK^T, row_m = rowmax(S)
            float row_m = -INFINITY;
            for (int y = 0; y < Bc; y++) {
                float sum = 0;
                for (int x = 0; x < d; x++) {
                    sum += Qi[(tx * d) + x] * Kj[(y * d) + x];
                }
                sum *= softmax_scale;
                S[(Bc * tx) + y] = sum;

                if (sum > row_m)
                    row_m = sum;
            }

            // P = exp(S - row_m), row_l = rowsum(P)
            float row_l = 0;
            for (int y = 0; y < Bc; y++) {
                S[(Bc * tx) + y] = __expf(S[(Bc * tx) + y] - row_m);
                row_l += S[(Bc * tx) + y];
            }

            // Compute new m and l
            float row_m_new = max(row_m_prev, row_m);
            float row_l_new = (__expf(row_m_prev - row_m_new) * row_l_prev) + (__expf(row_m - row_m_new) * row_l);

            // Write O, l, m to HBM
            for (int x = 0; x < d; x++) {
                float pv = 0;  // Pij * Vj
                for (int y = 0; y < Bc; y++) {
                    pv += S[(Bc * tx) + y] * Vj[(y * d) + x];
                }
                O[qkv_offset + (tile_size * i) + (tx * d) + x] = (1 / row_l_new) \
                    * ((row_l_prev * __expf(row_m_prev - row_m_new) * O[qkv_offset + (tile_size * i) + (tx * d) + x]) \
                    + (__expf(row_m - row_m_new) * pv));
            }
            m[lm_offset + (Br * i) + tx] = row_m_new;
            l[lm_offset + (Br * i) + tx] = row_l_new;
        }
        __syncthreads();  // otherwise, thread can use the wrong Kj, Vj in inner loop
    }
}

torch::Tensor forward(torch::Tensor Q, torch::Tensor K, torch::Tensor V) {
    // TODO: determine Bc, Br dynamically
    const int Bc = 32; const int Br = 32;

    const int B = Q.size(0); const int nh = Q.size(1);
    const int N = Q.size(2); const int d = Q.size(3);

    const int Tc = ceil((float) N / Bc); const int Tr = ceil((float) N / Br);
    const float softmax_scale = 1.0 / sqrt(d);

    // Initialize O, l, m to HBM
    auto O = torch::zeros_like(Q);
    auto l = torch::zeros({B, nh, N});
    auto m = torch::full({B, nh, N}, -INFINITY);
    torch::Device device(torch::kCUDA);
    l = l.to(device); m = m.to(device);

    // Calculate SRAM size needed per block
    const int sram_size = (3 * Bc * d * sizeof(float)) + (Bc * Br * sizeof(float));
    int max_sram_size;
    cudaDeviceGetAttribute(&max_sram_size, cudaDevAttrMaxSharedMemoryPerBlock, 0);
    printf("Max shared memory: %d, requested shared memory: %d \\n", max_sram_size, sram_size);

    dim3 grid_dim(B, nh);  // batch_size x num_heads
    dim3 block_dim(Bc);  // Bc threads per block

    forward_kernel<<<grid_dim, block_dim, sram_size>>>(
        Q.data_ptr<float>(), K.data_ptr<float>(), V.data_ptr<float>(),
        N, d, Tc, Tr, Bc, Br, softmax_scale,
        l.data_ptr<float>(), m.data_ptr<float>(), O.data_ptr<float>()
    );
    return O;
}

flash.cu 中实现了所有 flash attention 的细节逻辑,包括分块 softmax、累加输出、共享内存加载、并行计算等

下面我们分 forward 接口函数和 forward_kernel 逻辑实现核函数两部分来讲解

5.1 forward函数

forward 函数是 forward_kernel 核函数的启动函数,其代码如下:

torch::Tensor forward(torch::Tensor Q, torch::Tensor K, torch::Tensor V) {
    // TODO: determine Bc, Br dynamically
    const int Bc = 32; const int Br = 32;

    const int B = Q.size(0); const int nh = Q.size(1);
    const int N = Q.size(2); const int d = Q.size(3);

    const int Tc = ceil((float) N / Bc); const int Tr = ceil((float) N / Br);
    const float softmax_scale = 1.0 / sqrt(d);

    // Initialize O, l, m to HBM
    auto O = torch::zeros_like(Q);
    auto l = torch::zeros({B, nh, N});
    auto m = torch::full({B, nh, N}, -INFINITY);
    torch::Device device(torch::kCUDA);
    l = l.to(device); m = m.to(device);

    // Calculate SRAM size needed per block
    const int sram_size = (3 * Bc * d * sizeof(float)) + (Bc * Br * sizeof(float));
    int max_sram_size;
    cudaDeviceGetAttribute(&max_sram_size, cudaDevAttrMaxSharedMemoryPerBlock, 0);
    printf("Max shared memory: %d, requested shared memory: %d \\n", max_sram_size, sram_size);

    dim3 grid_dim(B, nh);  // batch_size x num_heads
    dim3 block_dim(Bc);  // Bc threads per block

    forward_kernel<<<grid_dim, block_dim, sram_size>>>(
        Q.data_ptr<float>(), K.data_ptr<float>(), V.data_ptr<float>(),
        N, d, Tc, Tr, Bc, Br, softmax_scale,
        l.data_ptr<float>(), m.data_ptr<float>(), O.data_ptr<float>()
    );
    return O;
}

下面我们来逐代码分析:

1. 函数签名和参数

torch::Tensor forward(torch::Tensor Q, torch::Tensor K, torch::Tensor V)
  • 这是在 c++/cuda 上实现的一个包装函数,直接对外暴露 pytorch tensor 接口
  • Q, K, V 是来自 pytorch 的张量,形状为 ( B , n h , N , d ) (B,nh,N,d) (B,nh,N,d)
    • B B B:batch size
    • n h nh nh:number of heads
    • N N N:序列长度
    • d d d:每个头的特征维度

此函数将返回一个和 Q 同形状的输出张量 O \mathbf{O} O

2. 函数体主要逻辑

2.1 确定分块大小 Bc, Br

// TODO: determine Bc, Br dynamically
const int Bc = 32; 
const int Br = 32;
  • 做法:硬编码设置块大小 B c = 32 , B r = 32 B_c=32,B_r=32 Bc=32,Br=32
  • 对应算法:这相当于在 flash attention 算法 步骤 1(分块大小设定)时,指定了列块大小、行块大小都为 32
  • 值得注意的是,通常实际实现中我们会根据硬件的 shared memory 大小、N、d 等情况动态决定最优块大小,这边简化成了固定值

2.2 获取张量形状并计算 Tc, Tr

const int B = Q.size(0);  // batch
const int nh = Q.size(1); // num_heads
const int N = Q.size(2);  // sequence length
const int d = Q.size(3);  // feature dim (per head)

const int Tc = ceil((float) N / Bc);
const int Tr = ceil((float) N / Br);
const float softmax_scale = 1.0 / sqrt(d);
  • 从 pytorch 的张量 Q 中获取四个维度:(B, nh, N, d)
  • Tc = ceil(N / Bc)Tr = ceil(N / Br)
    • 对应算法:这正是 步骤 3 所述,将键值矩阵 K , V \mathbf{K},\mathbf{V} K,V 沿序列维度分成 T c T_c Tc 块,将查询矩阵 Q \mathbf{Q} Q 沿序列维度分成 T r T_r Tr
  • softmax_scale = 1.0 / sqrt(d):缩放因子 1 d \frac{1}{\sqrt{d}} d 1

2.3 初始化输出张量 O \mathbf{O} O、向量 ℓ \ell 、向量 m m m

auto O = torch::zeros_like(Q);
auto l = torch::zeros({B, nh, N});
auto m = torch::full({B, nh, N}, -INFINITY);
torch::Device device(torch::kCUDA);
l = l.to(device); 
m = m.to(device);
  • 对应算法:这是 步骤 2 所描述的在片外内存 HBM 初始化输出矩阵 O \mathbf{O} O、归一化因子向量 ℓ \ell 和最大值向量 m m m
  • auto O = torch::zeros_like(Q);:创建一个和 Q \mathbf{Q} Q 同形状的全 0 张量,后面会在 kernel 中逐步累加注意力结果
  • auto l = torch::zeros({B, nh, N});:初始化 ℓ \ell 为 0
  • auto m = torch::full({B, nh, N}, -INFINITY);:初始化 m m m − ∞ -\infty
  • l = l.to(device);m = m.to(device);:将 l, m 放到 GPU 上

2.4 计算需要的共享内存大小并检查

// Calculate SRAM size needed per block
const int sram_size = (3 * Bc * d * sizeof(float)) + (Bc * Br * sizeof(float));
int max_sram_size;
cudaDeviceGetAttribute(&max_sram_size, cudaDevAttrMaxSharedMemoryPerBlock, 0);
printf("Max shared memory: %d, requested shared memory: %d \n",
        max_sram_size, sram_size);
  • 对应算法:是我们在 Require(算法基本设置)里所说的要确保片上内存 SRAM 足以容纳当前分块所需的 Q i , K j , V j , S i j \mathbf{Q}_i,\mathbf{K}_j,\mathbf{V}_j,\mathbf{S}_{ij} Qi,Kj,Vj,Sij 等中间结果
  • forward_kernel 中我们使用了 extern __shared__ float sram[]; 来动态分配 shared memory(on-chip SRAM)
  • 该内存需要容纳:
    • Kj, Vj, Qi 三个分块:Kj, Vj 分块各有 B c × d B_c\times d Bc×d 的大小,Qi 分块有 B r × d B_r \times d Br×d 的大小,总共 3 * Bc * d floats(本例中 B c = B r = 32 B_c=B_r=32 Bc=Br=32
    • S 矩阵: B c × B r B_c \times B_r Bc×Br 的大小
  • 通过 cudaDevAttrMaxSharedMemoryPerBlock 查询硬件支持的单个 block 最大共享内存 max_sram_size,用来做一个检查或调试信息打印

2.5 设置 kernel 的 gridDim 和 blockDim 并启动 kernel

dim3 grid_dim(B, nh);  // batch_size x num_heads
dim3 block_dim(Bc);    // Bc threads per block

forward_kernel<<<grid_dim, block_dim, sram_size>>>(
    Q.data_ptr<float>(), K.data_ptr<float>(), V.data_ptr<float>(),
    N, d, Tc, Tr, Bc, Br, softmax_scale,
    l.data_ptr<float>(), m.data_ptr<float>(), O.data_ptr<float>()
);
  • grid_dim(B, nh):2D-layout 网格维度,说明 GPU 上会有 B * nh 个线程块,每个 block 负责处理 一个 batch、一个 head 的全部序列 attention
  • block_dim(Bc):1D-layout 块维度,说明每个 block 里面有 Bc 个线程(本例中 Bc = 32),threadIdx.x 范围从 0 到 31
    • forward_kernel 里,我们将一整个 tile ( B c × d ) (B_c \times d) (Bc×d) 的数据加载到 shared memory,并让 Bc 个线程并行计算
    • 每个线程处理 1 个 token 即 1 × d 1\times d 1×d 维度的数据
  • sram_size:这是前面计算的 shared memory 大小,用来传给 <<<grid_dim, block_dim, sram_size>>> 作为第三个参数,表示动态分配多少 shared memory
  • 传入 kernel 的参数
    • Q.data_ptr<float>()K.data_ptr<float>()V.data_ptr<float>():指向 pytorch 张量 Q , K , V \mathbf{Q},\mathbf{K},\mathbf{V} Q,K,V 在 GPU 全局内存中的数据地址
    • N, d, Tc, Tr, Bc, Br, softmax_scale:各自维度和分块参数,以及 softmax 缩放因子
    • l.data_ptr<float>()m.data_ptr<float>()O.data_ptr<float>():同理,指向 ℓ , m , O \ell,m,\mathbf{O} ,m,O 在显存中的地址

2.6 返回输出结果

return O;
  • 返回输出张量 O \mathbf{O} O

总的来说,forward 函数在主机端(Host 端)做了三件事情:配置分块、分配输出与中间缓存O, l, m、启动 CUDA 核函数,在 forward_kernel 核函数内才真正完成了 Flash Attention 的分块计算,下面我们就来看看 forward_kernel 是怎么做的

5.2 forward_kernel函数

forward_kernel 核函数代码如下:

__global__
void forward_kernel(const float* Q, const float* K, const float* V, const int N, const int d,
                    const int Tc, const int Tr, const int Bc, const int Br, const float softmax_scale,
                    float* l, float *m, float* O) {
    int tx = threadIdx.x;
    int bx = blockIdx.x; int by = blockIdx.y;  // batch and head index

    // Offset into Q,K,V,O,l,m - different for each batch and head
    int qkv_offset = (bx * gridDim.y * N * d) + (by * N * d);  // gridDim.y = nh
    int lm_offset = (bx * gridDim.y * N) + (by * N);  // offset for l and m

    // Define SRAM for Q,K,V,S
    extern __shared__ float sram[];
    int tile_size = Bc * d;  // size of Qi, Kj, Vj
    float* Qi = sram;
    float* Kj = &sram[tile_size];
    float* Vj = &sram[tile_size * 2];
    float* S = &sram[tile_size * 3];

    for (int j = 0; j < Tc; j++) {

        // Load Kj, Vj to SRAM
        for (int x = 0; x < d; x++) {
            Kj[(tx * d) + x] = K[qkv_offset + (tile_size * j) + (tx * d) + x];
            Vj[(tx * d) + x] = V[qkv_offset + (tile_size * j) + (tx * d) + x];
        }
        __syncthreads();  // such that the inner loop can use the correct Kj, Vj

        for (int i = 0; i < Tr; i++)  {

            // Load Qi to SRAM, l and m to registers
            for (int x = 0; x < d; x++) {
                Qi[(tx * d) + x] = Q[qkv_offset + (tile_size * i) + (tx * d) + x];
            }
            float row_m_prev = m[lm_offset + (Br * i) + tx];
            float row_l_prev = l[lm_offset + (Br * i) + tx];

            // S = QK^T, row_m = rowmax(S)
            float row_m = -INFINITY;
            for (int y = 0; y < Bc; y++) {
                float sum = 0;
                for (int x = 0; x < d; x++) {
                    sum += Qi[(tx * d) + x] * Kj[(y * d) + x];
                }
                sum *= softmax_scale;
                S[(Bc * tx) + y] = sum;

                if (sum > row_m)
                    row_m = sum;
            }

            // P = exp(S - row_m), row_l = rowsum(P)
            float row_l = 0;
            for (int y = 0; y < Bc; y++) {
                S[(Bc * tx) + y] = __expf(S[(Bc * tx) + y] - row_m);
                row_l += S[(Bc * tx) + y];
            }

            // Compute new m and l
            float row_m_new = max(row_m_prev, row_m);
            float row_l_new = (__expf(row_m_prev - row_m_new) * row_l_prev) + (__expf(row_m - row_m_new) * row_l);

            // Write O, l, m to HBM
            for (int x = 0; x < d; x++) {
                float pv = 0;  // Pij * Vj
                for (int y = 0; y < Bc; y++) {
                    pv += S[(Bc * tx) + y] * Vj[(y * d) + x];
                }
                O[qkv_offset + (tile_size * i) + (tx * d) + x] = (1 / row_l_new) \
                    * ((row_l_prev * __expf(row_m_prev - row_m_new) * O[qkv_offset + (tile_size * i) + (tx * d) + x]) \
                    + (__expf(row_m - row_m_new) * pv));
            }
            m[lm_offset + (Br * i) + tx] = row_m_new;
            l[lm_offset + (Br * i) + tx] = row_l_new;
        }
        __syncthreads();  // otherwise, thread can use the wrong Kj, Vj in inner loop
    }
}

下面我们来逐代码分析:

核函数签名如下:

__global__
void forward_kernel(const float* Q, const float* K, const float* V, 
                    const int N, const int d,
                    const int Tc, const int Tr, 
                    const int Bc, const int Br, 
                    const float softmax_scale,
                    float* l, float* m, float* O) {
    ...
}
  • Q, K, V:全局内存 HBM 中的输入查询、键、值矩阵指针
  • N, d:序列长度和特征维度
  • Tc, Tr:分块个数,对应外层和内存循环需要循环的次数,分别是 ⌈ N B c ⌉ \lceil \frac{N}{B_c} \rceil BcN ⌈ N B r ⌉ \lceil \frac{N}{B_r} \rceil BrN
  • Bc, Br:列块和行块的分块大小
  • softmax_scale:缩放因子 1 d \frac{1}{\sqrt{d}} d 1,在计算 S i j \mathbf{S}_{ij} Sij 时需要乘以这个缩放因子
  • l, m:分块存储的归一化因子 ℓ \ell 和最大值向量 m m m
  • O O O:最终输出的注意力结果,与 Q \mathbf{Q} Q 同形状

1. 线程、块、批次与多头索引

int tx = threadIdx.x;
int bx = blockIdx.x; 
int by = blockIdx.y;  // batch 和 head 的二维索引
  • tx 表示线程(thread)在线程块(block)内的一维索引,范围是 0 ~ (Bc - 1)
  • bx 是线程块在 x 方向上的索引
  • by 是线程块在 y 方向上的索引
  • 这里 dim3 grid_dim(B, nh),所以 bx 对应 batch 维度、by 对应多头维度

2. 计算在 Q/K/V/O/l/m 上的偏移量

int qkv_offset = (bx * gridDim.y * N * d) + (by * N * d); 
int lm_offset  = (bx * gridDim.y * N) + (by * N);
  • 由于在外层调用 kernel 时,网格大小设置为 (B, nh),其中 B 是 batch 数,nh 是 head 数,因此:
    • gridDim.y = nh
    • qkv_offset 用于在 (batch, head) 两个维度上,定位到 Q , K , V \mathbf{Q},\mathbf{K},\mathbf{V} Q,K,V 的正确起始地址
    • lm_offset 用于定位 ℓ , m \ell,m ,m 在全局内存中的偏移

3. 声明共享内存并分块

extern __shared__ float sram[];
int tile_size = Bc * d;  // 大小 = 每个 tile 中 (Bc 行) * d 列
float* Qi = sram;
float* Kj = &sram[tile_size];
float* Vj = &sram[tile_size * 2];
float* S  = &sram[tile_size * 3];
  • extern __shared__ float sram[]; 表示这个核函数会动态分配 shared memory(片上 SRAM),其大小在 <<<...>>> 调用时由第三个参数指定
  • 这里将这片 shared memory 切分成 4 个部分:
    • Qi:用于存放 Q i \mathbf{Q}_i Qi 这个 tile,大小为 ( B r × d ) (B_r\times d) (Br×d)
    • kj:用于存放 K j \mathbf{K}_j Kj 这个 tile,大小为 ( B c × d ) (B_c\times d) (Bc×d)
    • Vj:用于存放 V j \mathbf{V}_j Vj 这个 tile,大小为 ( B c × d ) (B_c \times d) (Bc×d)
    • S:用于存放分块注意力得分矩阵 S i j \mathbf{S}_{ij} Sij,大小为 B c × B r B_c \times B_r Bc×Br
  • 由于我们需要一次性把这几个子块都放到 shared memory 中做计算,因此必须分配足够大的 shared memory

4. 外层循环,遍历所有键值块

for (int j = 0; j < Tc; j++) {
    ...
}
  • 这对应算法中的 外层循环(步骤 5),也就是遍历所有的 K j \mathbf{K}_j Kj V j \mathbf{V}_j Vj 的列块

4.1 加载 K j , V j \mathbf{K}_j,\mathbf{V}_j Kj,Vj 到共享内存

// Load Kj, Vj to SRAM
for (int x = 0; x < d; x++) {
    Kj[(tx * d) + x] = K[qkv_offset + (tile_size * j) + (tx * d) + x];
    Vj[(tx * d) + x] = V[qkv_offset + (tile_size * j) + (tx * d) + x];
}
__syncthreads();
  • 对应算法步骤 6:从片外 HBM 中加载当前列块 K j , V j \mathbf{K}_j,\mathbf{V}_j Kj,Vj 到片上共享内存
  • tile_size * j = (Bc * d) * j 表示第 j 个列块在全局 K/V 内存中的起始位置
  • tx 表示线程索引,这里假设每个线程拷贝一行的(即 Bc 行的其中一行)d 个元素,所以 Kj/Vj 的索引是 (tx * d) + x
  • __syncthreads(); 保证 block 内所有线程都完成数据加载后再进行后续操作,以确保共享内存里 Kj, Vj 是完整的

关于索引的计算博主绘制了一个草图来说明:

在这里插入图片描述

Note:我们以键矩阵为例来说明(值矩阵类似)

K K K 在内存中是连续存储的,数据总长度是 B * nh * N * d,现在我们的目的是找到全局 K K K 内存中某个 tile 块中的某个元素(图示红色小块)的全局位置索引

我们可以借用 NVIDIA 的 gridDim 和 blockDim 类似的布局思想,假设 HBM 中存储的 K K K 是以 2D-layout 布局,总共有 B * nh 个大块,每个大块的维度是 N * d

每个小块被切分成了 Tc 个 tile(本例中为 2),则图中红色小块的全局位置索引可以通过 qkv_offset + (tile_size * j) + (tx * d) + x 获取,其中:

  • qkv_offset 定位到 B * nh 个大块中具体哪个大块
  • tile_size * j 定位到 Tc 个 tile 中的哪个 tile
  • tx * d 定位到哪个 thread 处理
  • x 定位到属于 d 维度的哪个元素

5. 内层循环:遍历所有查询块

for (int i = 0; i < Tr; i++)  {
    ...
}
  • 对应算法步骤 7:遍历行块 Q i \mathbf{Q}_i Qi,每个 i 表示一个查询子块(行块),大小为 B r × d B_r \times d Br×d

5.1 加载 Q i \mathbf{Q}_i Qi O i \mathbf{O}_i Oi ℓ i , m i \ell_i,m_i i,mi

// Load Qi to SRAM, l and m to registers
for (int x = 0; x < d; x++) {
    Qi[(tx * d) + x] = Q[qkv_offset + (tile_size * i) + (tx * d) + x];
}
float row_m_prev = m[lm_offset + (Br * i) + tx];
float row_l_prev = l[lm_offset + (Br * i) + tx];
  • 对应算法步骤 8:把当前的 Q i \mathbf{Q}_i Qi 分块加载进共享内存 Qi,同时从全局内存读入对应该行块的 m_il_i 到寄存器
  • 因为要对 ℓ i \ell_i i(softmax 归一化因子)和 m i m_i mi(最大值)做更新,需要先读旧的值。这里用 row_l_prevrow_m_prev 表示 ℓ i \ell_i i m i m_i mi 的旧值

Note:这里并没有显式加载 O i \mathbf{O}_i Oi 到共享内存,而是会在更新时直接访问全局内存的 O[...] 进行读写

5.2 计算分块注意力分数 S i j = Q i K j T \mathbf{S}_{ij} = \mathbf{Q}_i\mathbf{K}^T_j Sij=QiKjT

// S = QK^T, row_m = rowmax(S)
float row_m = -INFINITY;
for (int y = 0; y < Bc; y++) {
    float sum = 0;
    for (int x = 0; x < d; x++) {
        sum += Qi[(tx * d) + x] * Kj[(y * d) + x];
    }
    sum *= softmax_scale;
    S[(Bc * tx) + y] = sum;

    if (sum > row_m)
        row_m = sum;
}
  • 对应算法步骤 9:计算 S i j = Q i K j T \mathbf{S}_{ij} = \mathbf{Q}_i\mathbf{K}^T_j Sij=QiKjT
  • 其中:
    • Qi[(tx * d) + x] 取到该线程处理的某行 Q i \mathbf{Q}_i Qi
    • Kj[(y * d) + x] 取到 K j \mathbf{K}_j Kj 的第 y
    • 然后逐元素相乘累加得到 sum
    • 乘以缩放因子 softmax_scale
    • 将结果写入共享内存 S[(Bc * tx) + y]
    • 这里也顺便计算行最大值 row_m,相当于 m ~ i j = r o w m a x ( S i j ) \tilde{m}_{ij} = \mathrm{rowmax}(\mathbf{S}_{ij}) m~ij=rowmax(Sij)

5.3 做局部 softmax: P ~ i j = e x p ( S i j − m ~ i j ) \tilde{\mathbf{P}}_{ij} = \mathrm{exp}(\mathbf{S}_{ij}-\tilde{m}_{ij}) P~ij=exp(Sijm~ij),并求和

// P = exp(S - row_m), row_l = rowsum(P)
float row_l = 0;
for (int y = 0; y < Bc; y++) {
    S[(Bc * tx) + y] = __expf(S[(Bc * tx) + y] - row_m);
    row_l += S[(Bc * tx) + y];
}
  • 对应算法步骤 10:对注意力分数做分块式 softmax:
    • 减去行最大值 row_m(即 m ~ i j \tilde{m}_{ij} m~ij)以做数值稳定
    • 指数化 exp ⁡ ( ⋅ ) \exp(\cdot) exp(),得到 P ~ i j \tilde{\mathbf{P}}_{ij} P~ij
    • 行向量求和,得到局部归一化因子 ℓ ~ i j \tilde{\ell}_{ij} ~ij,对应代码中的 row_l 变量

5.4 计算新的全局最大值与归一化因子

// Compute new m and l
float row_m_new = max(row_m_prev, row_m);
float row_l_new = (__expf(row_m_prev - row_m_new) * row_l_prev) 
                + (__expf(row_m - row_m_new) * row_l);
  • 对应算法步骤 11:计算新全局向量 m i n e w , ℓ i n e w m_i^{\mathrm{new}},\ell_i^{\mathrm{new}} minew,inew,参考公式:
    • m i n e w = max ⁡ ( m i , m ~ i j ) m_i^{\mathrm{new}}=\max(m_i,\tilde{m}_{ij}) minew=max(mi,m~ij)
    • ℓ i n e w = e m i − m i n e w ℓ i + e m ~ i j − m i n e w ℓ ~ i j \ell_i^{\mathrm{new}}=e^{m_i-m_i^{\mathrm{new}}}\ell_i+e^{\tilde{m}_{ij}-m_i^{new}}\tilde{\ell}_{ij} inew=emiminewi+em~ijminew~ij
  • 这里 row_m_prev 相当于上一次累积的 m i m_i mirow_m 是本块算出的 m ~ i j \tilde{m}_{ij} m~ij,分别对应公式中的 m i m_i mi m ~ i j \tilde{m}_{ij} m~ij
  • row_l_prev 是旧的 ℓ i \ell_i irow_l 是本块的 ℓ ~ i j \tilde{\ell}_{ij} ~ij
  • row_m_newrow_l_new 就是更新后的 m i n e w m_i^{\mathrm{new}} minew ℓ i n e w \ell_i^{\mathrm{new}} inew

5.5 更新输出 O i \mathbf{O}_i Oi

for (int x = 0; x < d; x++) {
    float pv = 0;  // Pij * Vj
    for (int y = 0; y < Bc; y++) {
        pv += S[(Bc * tx) + y] * Vj[(y * d) + x];
    }
    O[qkv_offset + (tile_size * i) + (tx * d) + x] 
        = (1 / row_l_new) * (
             (row_l_prev * __expf(row_m_prev - row_m_new) 
               * O[qkv_offset + (tile_size * i) + (tx * d) + x])
           + (__expf(row_m - row_m_new) * pv)
          );
}
  • 对应算法步骤 12:分块累加更新 O i \mathbf{O}_i Oi,根据公式:
  • O i ← d i a g ( ℓ i n e w ) − 1 ( d i a g ( ℓ i ) e m i − m i n e w O i + e m ~ i j − m i n e w P ~ i j V j ) \mathbf{O}_i\leftarrow \mathrm{diag}(\ell_i^{\mathrm{new}})^{-1}(\mathrm{diag}(\ell_i)e^{m_i-m_i^{\mathrm{new}}}\mathbf{O}_i+e^{\tilde{m}_{ij}-m_i^{\mathrm{new}}}\tilde{\mathbf{P}}_{ij}\mathbf{V}_j) Oidiag(inew)1(diag(i)emiminewOi+em~ijminewP~ijVj)
  • 对应到代码里:
    • pv 对应 P ~ i j V j \tilde{\mathbf{P}}_{ij}\mathbf{V}_j P~ijVj 逐元素相乘再累加
    • (row_l_prev * __expf(row_m_prev - row_m_new) * O[...] 对应旧的 O i \mathbf{O}_i Oi 部分乘上相应的因子 d i a g ( ℓ i ) e m i − m i n e w \mathrm{diag}(\ell_i)e^{m_i-m_i^{\mathrm{new}}} diag(i)emiminew
    • __expf(row_m - row_m_new) * pv 对应新的部分
    • (1 / row_l_new) 则是对应乘上的 d i a g ( ℓ i n e w ) − 1 \mathrm{diag}(\ell_i^{\mathrm{new}})^{-1} diag(inew)1 这个归一化因子

5.6 写回 ℓ i , m i \ell_i,m_i i,mi

m[lm_offset + (Br * i) + tx] = row_m_new;
l[lm_offset + (Br * i) + tx] = row_l_new;
  • 对应算法步骤 13:将新的 m i n e w , ℓ i n e w m_i^{\mathrm{new}}, \ell_i^{\mathrm{new}} minew,inew 写回全局内存 HBM
  • 完成对第 i 块(行块)的处理,内层循环推进到下一块
__syncthreads();
  • 对应算法步骤 14:内层循环结束(j++ 继续下一个 tile),用 __syncthreads() 确保当前的 Kj, Vj 处理完成

6. 外层循环结束

  • 对应算法步骤 15:当所有 j 都完成后,外层循环也结束,即完成了遍历所有 K j , V j \mathbf{K}_j,\mathbf{V}_j Kj,Vj

总的来说,flash.cu 这份 CUDA 代码是一个比较直接的 flash attention 分块公式实现:每次加载一块 K j , V j \mathbf{K}_j,\mathbf{V}_j Kj,Vj,然后对所有 Q i \mathbf{Q}_i Qi 进行 partial softmax、partial output 累加,更新 O i , ℓ i , m i \mathbf{O}_i,\ell_i,m_i Oi,i,mi,当所有块都处理完后,就得到完整的注意力输出。

结语

flash-attention-minimal 中,最小化的 flash attention 实现过程还是非常清晰的,其大致思路是:

1. 先在 cpu/python 端根据 N , d N,d N,d 等参数设置好分块大小 B c , B r B_c,B_r Bc,Br,初始化 O , ℓ , m \mathbf{O},\ell, m O,,m

2. 在 gpu 端,使用外层循环遍历所有列块 K j , V j \mathbf{K}_j,\mathbf{V}_j Kj,Vj,并将它们加载到共享内层

3. 内层循环遍历所有行块 Q i \mathbf{Q}_i Qi,并同样加载到共享内层后,计算分块注意力分数 S i j \mathbf{S}_{ij} Sij,执行分块 softmax,更新全局归一化因子 ℓ \ell 与输出 O \mathbf{O} O

4. 循环完成后, O \mathbf{O} O 就是完整的注意力输出

代码里每一部分(加载、乘法、指数化、最大值即归一化更新、输出累加写回)都能和 flash attention 论文中的算法伪代码实现对应上,作为 flash attention 代码实现的入门还是非常不错的

大家感兴趣的可以看看,整个实现非常的简洁🤗

参考

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

爱听歌的周童鞋

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值