目录
2.2 Standard Attention Implementation
3.FlashAttention: Algorithm, Analysis, and Extensions
3.1 An Efficient Attention Algorithm With Tiling and Recomputation
3.2 Analysis: IO Complexity of FlashAttention
3.3 Extension: Block-Sparse FlashAttention
4.1 Faster Models with FlashAttention
4.2 Better Models with Longer Sequences
5.Limitations and Future Directions
Abstract
Transformer 在长序列上很慢且需要大量内存,因为 Self-Attention 的时间和内存复杂度在序列长度上是 O(N^2) 的。近似 Attention 方法试图通过权衡模型质量来解决这个问题,以降低计算复杂度,但通常无法实现挂钟加速。我们认为缺失的原则是使注意力算法 IOaware - 考虑到 GPU 内存级别之间的读写。我们提出了 FlashAttention,这是一种 IO 感知的精确注意算法,它使用平铺来减少 GPU 高带宽内存 (HBM) 和 GPU 片上 SRAM 之间的内存读/写的数量。我们分析了 FlashAttention 的 IO 复杂度,表明它需要比标准注意力更少的 HBM 访问,并且对于一系列 SRAM 大小是最优的。我们还将 FlashAttention 扩展到块稀疏注意力,产生了比任何现有的近似注意力方法更快的近似注意力算法。FlashAttention 比现有基线更快地训练 Transformer:与 MLPerf 1.1 训练速度记录相比,BERT-large (seq. length 512) 的端到端挂钟加速 15%,GPT-2 上的加速 3 (seq. length 1K) 和远程竞技场上的 2.4 加速 (seq. length 1K-4K)。FlashAttention 和块稀疏 FlashAttention 在 Transformer 中实现了更长的上下文,产生了更高质量的模型 (GPT-2 的困惑度为 0.7,长文档分类提升 6.4 分) 和全新的能力:第一个 Transformer 在 Path-X 挑战 (seq. 长度 16K,61.4% 的准确率) 和 Path-256 (seq. 长度 64K,63.1% 的准确率)。
1.Introduction
Transformer 模型已成为自然语言处理和图像分类等应用中使用最广泛的架构。Transformers 的增长更大和更深,但将它们配备更长的上下文仍然很困难,因为其心脏的自注意力模块具有序列长度的二次时间和内存复杂度。一个重要的问题是,使注意力更快、内存效率更高,可以帮助 Transformer 模型解决它们对长序列的运行时和内存挑战。
许多近似注意方法旨在减少注意力的计算和内存需求。这些方法范围从稀疏近似到低秩近似及其组合。尽管这些方法减少了序列长度的线性或接近线性的计算要求,但它们中的许多并没有显示与标准注意力挂钟加速,也没有获得广泛采用。一个主要原因是他们专注于 FLOP 减少 (可能与挂钟速度无关),并且倾向于忽略内存访问 (IO) 的开销。
在本文中,我们认为缺失的原则是使注意力算法 IO 感知 - 即仔细考虑对不同级别的快速和慢速内存(例如,在快速 GPU 片上 SRAM 和相对较慢的 GPU 高带宽内存之间,或 HBM )。在现代GPU,计算速度已经超过了内存速度,Transformers 中的大多数操作都受到内存访问的瓶颈。IO 感知算法对于类似的内存绑定操作至关重要,当读取和写入数据时,可以解释大部分运行时 - 例如数据库连接 、图像处理 、数值线性代数等。然而,PyTorch 和 Tensorflow 等深度学习的常见 Python 接口不允许对内存访问进行细粒度控制。
我们提出了 FlashAttention,这是一种新的注意力算法,它以更少的内存访问计算准确的注意力。我们的主要目标是避免将 Attention 矩阵读写到 HBM 中。这需要在不访问整个输入的情况下计算 softmax reduction,而无需存储大的中间注意力矩阵以进行后向传递。我们应用两种成熟的技术来应对这些挑战。
(i) 我们重组注意力计算,将输入分成块,并对输入块进行多次传递,从而逐步执行 softmax 减少(也称为平铺)。
(ii) 我们从前向传递中存储 softmax 归一化因子,以在后向传递中快速重新计算片上的注意力,这比从 HBM 读取中间注意力矩阵的标准方法更快。
我们在 CUDA 中实现 FlashAttention,实现对内存访问的细粒度控制,并将所有 Attention 操作融合到一个 GPU 内核中。即使由于重新计算而导致的 FLOP 增加,我们的算法也运行得更快 (GPT-2 上高达 7.6 倍),并且由于 HBM 访问量大大减少,使用更少的内存 - 序列长度线性。
我们分析了 FlashAttention 的 IO 复杂度,其中 HBM 的复杂度为: 𝑂(𝑁^2 𝑑^2 𝑀^-1),其中 d 为 embedding 维度、M 为 SRAM 的大小,作为对比,标准 Attention 的复杂度为 Ω(𝑁 𝑑 + 𝑁^2)。与标准注意力相比,FlashAttention 的典型值需要的 HBM 访问次数要少得多 (最多减少 9 个,如图下图)。此外,我们提供了一个下限,表明没有确切的注意力算法可以渐近地提高所有 SRAM 大小的 HBM 访问次数。
我们还表明,FlashAttention 可以作为实现近似注意力算法潜力的有用原语,通过克服内存访问开销的问题。作为概念证明,我们实现了块稀疏 FlashAttention,这是一种稀疏注意力算法,比 FlashAttention 快,扩展到 64k 的序列长度。我们证明了块稀疏 FlashAttention 比 FlashAttention具有更好的 IO 复杂度,其因子与稀疏比成正比。我们在第 5 节中讨论了对其他操作 (多 GPU 上的注意力、内核回归、块稀疏矩阵 2 乘法) 的进一步扩展。我们开源 FlashAttention 使其更容易在此原语上构建。
我们凭经验验证 FlashAttention 通过建模更长的上下文来加速模型训练并提高模型质量。与之前的注意力实现相比,我们还对 FlashAttention 和块稀疏 FlashAttention 的运行时和内存占用进行了基准测试。
• 更快的模型训练。FlashAttention 在挂钟时间内更快地训练 Transformer 模型。我们在 MLPerf 1.1 、GPT2 (seq. length 1K) 中训练 BERT-large(seq. length 512)比 HuggingFace 和 Megatron-LM 的基线实现快 15%,远程竞技场(seq. length 1K-4K)比基线快 2.4x。
• 更高质量的模型。FlashAttention 将 Transformer 扩展到更长的序列,这提高了它们的质量并实现新功能。我们观察到 GPT-2 的困惑度提高了 0.7,在长文档分类上从建模更长的序列提升了 6.4 分。FlashAttention 使第一个 Transformer 能够在 Path-X 挑战上取得比机会更好的性能,仅使用更长的序列长度 (16K)。Block-sparse FlashAttention 使 Transformer 能够扩展到更长的序列 (64K),从而产生第一个可以在 Path-256 上获得更好的性能的模型。
• 对注意力进行基准测试。FlashAttention 在从 128 到 2K 的公共序列长度上比标准注意力实现快 3x,并扩展到 64K。直到序列长度为 512,FlashAttention 比任何现有的注意力方法更快、内存效率更高,而对于超过 1K 的序列长度,一些近似注意力方法(例如 Linformer)开始变得更快。另一方面,块稀疏 FlashAttention 比我们知道的所有现有的近似注意力方法更快。
2.Background
我们提供了一些关于现代硬件 (GPU) 上常见深度学习操作的性能特征的背景。我们还描述了注意力的标准实现。
2.1 Hardware Performance
[硬件性能]
我们在这里专注于 GPU。其他硬件加速器的性能相似。
GPU 内存层次结构。GPU内存层次结构 (上图) 包含多种形式的不同大小和速度的内存,内存更小。例如,A100 GPU 具有 40-80GB 的高带宽内存 (HBM),带宽为 1.5-2.0TB/s,每 108 个流式多处理器每 192KB 的片上 SRAM,带宽约为 19TB/s。片上 SRAM 比 HBM 快一个数量级,但尺寸小很多个数量级。随着计算相对于内存速度更快,操作越来越多地受到内存 (HBM) 访问的瓶颈。因此,利用快速 SRAM 变得更加重要。
执行模型。GPU 有大量线程来执行操作(称为内核)。每个内核将 HBM 的输入加载到寄存器和 SRAM 中,计算,然后将输出写入 HBM。
性能特点。根据计算和内存访问的平衡,操作可以分为计算绑定或内存限制。这通常用算术强度来衡量,它是每个字节内存访问的算术运算数。
- 计算绑定:操作所花费的时间由存在多少算术运算决定,而访问 HBM 的时间要小得多。典型的例子是矩阵乘以大的内部维度,以及具有大量通道的卷积。
- 内存限制:操作所花费的时间由内存访问的数量决定,而计算所花费的时间要小得多。示例包括大多数其他操作:elementwise (激活、dropout) 和减少 (sum、softmax、batchnorm、layernorm)。
内核融合。加速内存绑定操作最常见的方法是内核融合:如果对同一输入应用多个操作,则可以从 HBM 加载一次输入,而不是每次操作多次。编译器可以自动融合许多 elementwise 操作。然而,在模型训练的背景下,仍然需要将中间值写入 HBM 以节省后向传递,从而降低朴素内核融合的有效性。
2.2 Standard Attention Implementation
[标准 Attention 流程]
给定输入序列 Q/K/V ∈ R^{Nxd}, 其中 N 是序列长度,d 是头部维度,我们想要计算注意力输出 O ∈ R^{Nxd}:
其中 softmax 是按行应用的,一般计算时 dim=1。
标准注意力实现将注意力矩阵 S 和归一化权重 P 具体化为 HBM,这需要 O(N^2) 的空间。一般 N >> d,以 GPT2 为例其序列长度 N = 1024,Head 维度 d = 64。我们在算法 0 中描述了标准的注意实现。由于一些或大部分操作都是内存绑定的 (例如softmax),大量的内存访问转化为缓慢的挂钟时间。
其他应用于 Attention 矩阵的元素操作加剧了这个问题,例如应用于 S 的 Mask 或应用于 P 的 Dropout。因此,已经有很多尝试融合几个元素操作,例如将 Mask 与 softmax 融合。
在第 3.2 节中,我们将展示标准注意力实现在序列长度 N 中执行 HBM 访问二次方。我们还比较了标准注意力和我们的方法(FlashAttention)的 FLOP 和 HBM 访问次数。
Tips:
上面算法可以看到 Q/K/S/P/V 都经历了从 HBM 读取,计算,再写入 HBM 的操作,由于大部分是内存绑定,即内存读取时间 > 计算速度,因此内存访问是这里慢的主要原因。
3.FlashAttention: Algorithm, Analysis, and Extensions
[FlashAttention: 算法、分析和扩展]
我们展示了如何用更少的 HBM 读取/写入计算准确的注意力,而无需存储大型中间矩阵以进行后向传递。这产生了一种注意力算法,该算法在挂钟时间内既高效又更快。我们分析了它的 IO 复杂性,表明与标准注意力相比,我们的方法需要更少的 HBM 访问。我们进一步表明 FlashAttention 可以通过扩展它来处理块稀疏注意力来充当有用的原语。
为了便于说明,我们在这里专注于前向传递;附录 B 包含后向的详细信息。
3.1 An Efficient Attention Algorithm With Tiling and Recomputation
[一种高效的 Tiling 和 Recomput 注意算法]
给定输入 Q/K/V ∈ R^{Nxd},我们的目标是计算注意力输出 O ∈ R^{Nxd} 并将其写入HBM。我们的目标是减少 HBM 访问量 (低于 N 的二次即 sub-quadratic N)。
我们应用两种已建立的技术 (tiling 平铺、recomputation 重新计算) 来克服在次二次 HBM 访问中计算精确注意力的技术挑战。我们在算法 1 中描述这一点。其主要思想是,我们将输入Q/K/V 分成块,将它们从慢 HBM 加载到快速 SRAM 中,然后计算关于这些块的注意力输出。通过在添加每个块之前通过正确的归一化因子缩放每个块的输出,我们最终得到正确的结果。
Tiling 平铺。
我们通过块计算注意力。Softmax 对 K 的列进行耦合,因此我们用缩放分解大 softmax。对于数值稳定性,向量的 softmax ∈ R^B 通过下述方式计算:
对于序列 [x1, x2, ... , xB],m(x) 函数负责计算序列的最大值;f(x) 负责计算 safe exp,通过减去 m(x) 防止浮点溢出;l(x) 负责求和,softmax 遵循常规 Attention 的计算步骤,exp 后归一化。
对于分解的 x1,x2,我们可以分解连接的 softmax。假设 x = [x1, x2] ∈ R^{2B},此时:
m(x) = m([x1, x2]) = max(m(x1), m(x2)),f(x) 的计算做了拆分,分别做了一个缩放,其中缩放因子根据 m(x1)、m(x2) 与 m(x) 的差值而定,通过 exp 的加性进行还原,是的最终结果 Safe Attention 时 exp 指数部分减去的数字为 m(x),l(x) 则类似,通过全局 Sum 获取。
上面的算法如果我们跟踪一些额外的统计数据 m(x)、l(x),我们可以一次计算 softmax 一个块。我们将 Q/K/V 分为多个块,计算 softmax 值以及额外的统计数据,最终结合结果。
Recomputation 重新计算。
我们的目标之一是不存储 O(N^2) 后向传递的中间值。后向传递通常需要矩阵 S/P ∈ R^{NxN} 计算 Q/K/V 的梯度。然而,通过存储输出 O 和 softmax 归一化统计 (m,l),我们可以在 SRAM 中从 Q/K/V 的块向后传递中轻松重新计算注意力矩阵 S 和 P。这可以看作是一种选择性梯度检查点的形式。虽然梯度检查点已被证明可以减少所需的最大内存量,但所有实现都必须权衡内存的速度。相比之下,即使使用更多的 FLOP,由于 HBM 访问的减少,我们的重新计算也会加速后向传递。完整的后向传递描述在附录 B 中。
Implementation details 实施细节: Kernel fusion 内核融合。
Tiling 使我们能够在一个 CUDA 内核中实现我们的算法,从 HBM 加载输入,执行所有计算步骤(矩阵乘法、softmax、可选的掩码和 dropout、矩阵乘法),然后将结果写回 HBM(附录 B 中的掩码和 dropout)。这避免了反复阅读和写入输入和输出从 HBM 到 HBM。
我们展示了 FlashAttention 的正确性、运行时间和内存要求 (附录 C 中的证明)。
Theorem 1 定理:
算法 1 返回 O = Softmax(QK^T)V 需要 O(N^2d) 的 FLOPs 以及排除输入输出外 O(N) 的的额外空间。
3.2 Analysis: IO Complexity of FlashAttention
[分析:FlashAttention 的 IO 复杂性]
我们分析了 FlashAttention 的 IO 复杂性,与标准注意力相比,HBM 访问显着减少。我们还提供了一个下限,证明没有精确的注意力算法可以在所有 SRAM 大小上,HBM 访问的渐近改进。证明在附录 C 中。
Theorem 2 定理:
给定序列长度 N、头维度 d 以及 SRAM 的 size M,其满足 𝑑 ≤ 𝑀 ≤ 𝑁𝑑。标准 Attention 需要 O(Nd+N^2) 的 HBM 访问次数,Flash Attention 需要 O(𝑁^2·𝑑^2 / 𝑀) 的 HBM 访问次数。
对于 d(64-128) 和 M(大约100KB) 的典型值,d^2比 M 小很多倍,因此 FlashAttention 比标准实现需要更少的 HBM 访问次数。这导致了更快的执行和更低的内存占用,我们在第 4.3 节中对其进行了验证。
证明的主要思想是,给定大小为 M 的 SRAM,我们可以分别加载 O(M) 的 K/V 的块,对于每一个 K/V 块,我们遍历 Q 的所有块计算中间值,总共访问 𝑁𝑑 / 𝑀 次 Q,每次访问 𝑁𝑑 个元素,最终得到 O(𝑁^2·𝑑^2 / 𝑀) 次 HBM 的访问。我们同样证明了标准注意力的后向传递需要 O(𝑁 𝑑 + 𝑁^2) 次 HBM 的访问,而 FlashAttention 的反向传递需要 O(𝑁^2·𝑑^2 / 𝑀) 次 HBM 访问。
我们证明了一个下限:在计算精确注意力时,不能渐近地改进 M (SRAM 大小) 的所有值的 HBM 访问次数。
Proposition 3 命题:
给定序列长度 N、头维度 d 以及 SRAM 的 size M,其满足 𝑑 ≤ 𝑀 ≤ 𝑁𝑑。不存在一种算法来计算精确注意对于 M ∈ [d, Nd]。
证明依赖于这样一个事实,即对于 M = O(𝑁𝑑) 任何算法必须执行 Ω(𝑁^2·𝑑^2 / 𝑀) = Ω(𝑁𝑑) 次 HBM 的访问。在流算法文献中,M 的子范围上的这种类型的下界很常见。我们将证明参数化复杂度的下界 M 留作令人兴奋的未来工作。
我们验证了 HBM 访问的数量是注意力运行时间的主要决定因素。在 图 2-1 中,我们看到与标准注意力 (由于后向传递中的重新计算) 相比,FlashAttention 具有更高的 FLOP 计数,但它的 HBM 访问要少得多,从而导致更快的运行时间。在 图 2-2 中,我们改变了 FlashAttention 的块大小 Bc,这导致了不同数量的 HBM 访问,并测量了前向传递的运行时间。随着块大小的增加,HBM 访问的数量减少 (因为我们对输入进行更少的传递),运行时间减少。对于足够大的块大小 (超过 256),运行时间然后受到其他因素 (例如算术运算) 的瓶颈。此外,较大的块大小将不适合较小的 SRAM 大小。
3.3 Extension: Block-Sparse FlashAttention
[扩展: 块稀疏FlashAttention]
我们将 FlashAttention 扩展到近似注意力:我们提出了块稀疏 FlashAttention,其 IO 复杂度小于 FlashAttention,其因子与稀疏性成正比。
给定一个预定义的块稀疏掩码:
我们可以很容易地调整算法 1,只计算注意矩阵的非零块。该算法与算法 1 相同,只是我们跳过了零块。我们在附录 B 中重现算法 5 中的算法描述。
我们还分析了块稀疏 FlashAttention 的 IO 复杂度。
Proposition 4 命题:
给定序列长度 N、头维度 d 以及 SRAM 的 size M,其满足 𝑑 ≤ 𝑀 ≤ 𝑁𝑑。Block-sparse FlashAttention 需要 Θ(𝑁𝑑 + 𝑁^2 · 𝑑^ · s / 𝑀)。其中 s 为块稀疏掩码中非零块的比例。
我们看到,在 IO 复杂度中,将块稀疏性应用于更大项会产生直接的改进。对于大序列长度 N,s 一般设置为 N^{-1/2} 或者 N^{-1}logN,导致 Θ(N sqrt{N}) 和 Θ(N logN) 的 IO 复杂度。对于下游实验,我们使用固定的蝴蝶稀疏模式,已被证明能够近似任意稀疏性。
在 图 2-3 中,我们验证了随着稀疏性的增加,块稀疏 FlashAttention 的运行时间成比例地提高。在 LRA 基准测试中,块稀疏 FlashAttention 实现了 2.8 的加速,同时与标准注意力相当。
4.Experiments
我们评估了使用 FlashAttention 来训练 Transformer 模型的影响。我们验证了关于训练时间和模型精度的两个声明,并报告了注意力运行时和内存基准。
• 训练速度
与标准 Transformer 相比,FlashAttention 比 BERT 的 MLPerf 1.1 速度记录高出 15%,在 HuggingFace 上将 GPT-2 的速度速度提高了 3x,在 Megatron 上速度提高了 1x 倍以上。FlashAttention 加快了远程竞技场 (LRA) 基准 2.4x。
• 质量
FlashAttention 将 Transformer 扩展到更长的序列,从而产生更高质量的。FlashAttention 使用上下文长度为 4K 的 GPT-2 训练上下文长度为 1K 的 Megatron 训练 GPT-2,同时实现 0.7 更好的困惑度。对两个长文档分类任务进行建模会产生 6.4 分的提升。最后,FlashAttention 产生了第一个 Transformer,可以在具有挑战性的 Path-X 任务(序列长度 16K)上实现比随机更好的性能,并且块稀疏 FlashAttention 产生了我们知道这可以在 Path-256(序列长度 64K)上实现比随机更好的性能的第一个序列模型。
• 基准注意力
我们测量了基于序列长度的 FlashAttention 和块稀疏 FlashAttention 的运行时和内存性能。我们确认 FlashAttention 的内存占用与 seq 成线性关系。长度,并且比普通 seq 的标准注意力快 3 倍。长度(最多 2K)。我们确认块稀疏 FlashAttention 的运行时间在 seq 中线性缩放。长度,并且比所有现有的近似注意力基线更快。
额外的实验细节在附录 E 中。
4.1 Faster Models with FlashAttention
[带有 FlashAttention 的更快模型]
BERT
FlashAttention 产生了我们知道最快的单节点 BERT 训练速度。我们在 Wikipedia 上训练了一个带有 FlashAttention 的 BERT-large 模型。Table-1 将我们的训练时间与 Nvidia 的实现进行了比较,后者为 MLPerf 1.1 设置训练速度记录。我们的实现快 15%。Table-1:BERT-large 的训练时间,从 MLPerf 基准提供的相同初始化开始,在掩码语言建模上达到 72.0% 的目标准确率。在 8 个 A100 GPU 上平均超过 10 次运行。
GPT-2
与广泛使用的 HuggingFace 和 Megatron-LM 实现相比,FlashAttention 在大型 OpenWebtext 数据集上为 GPT-2 产生了更快的训练时间。Table-2 显示了与 Huggingface 相比高达 3x 的端到端加速,与 Megatron-LM 相比加速比高达 1.7x。FlashAttention 实现了与其他两种实现相同的困惑度,因为我们不更改模型定义。附录 E 包括在整个训练过程中验证困惑度图,证实 FlashAttention 与基线的数值稳定并产生相同的训练/验证曲线。Table-2: 与 Huggingface 实现相比,使用 FlashAttention 的 GPT-2 小培养基和培养基实现了高达 3x 的速度,与 Megatron-LM 相比高达 1.7x。8 个 A100s GPU 上报告的训练时间。
Long-range Arena
我们在远程竞技场 LRA 基准上比较了 vanilla Transformer (与标准实现或 FlashAttention)。我们测量所有模型的准确度、吞吐量和训练时间。每个任务都有不同的序列长度,在 1024 到 4096 之间变化。我们遵循 Tay 等人和 Xiong 等人的实现和实验设置。Table3.3 显示,与标准注意力相比,FlashAttention 实现了 2.4x 倍的加速。Block-sparse FlashAttention 比我们测试的所有近似注意力方法更快。Table-3: 在 Long-Range-Arena 基准上标准注意力、FlashAttention、块稀疏 FlashAttention 和近似注意力基线的性能。
4.2 Better Models with Longer Sequences
Language Modeling with Long Context
FlashAttention 的运行时和内存效率使我们能够将 GPT-2 的上下文长度增加 4,同时仍然比 Megatron-LM 的优化实现运行得更快。表 4 显示,具有 FlashAttention 和上下文长度为 4K 的 GPT-2 在上下文长度为 1K 的威震天仍然比 GPT-2 快 30%,同时实现了 0.7 更好的困惑度。
Long Document Classification
使用 FlashAttention 训练具有较长序列的 Transformer 可以提高 MIMIC-III 和 ECtHR 数据集上的性能。MIMIC-III 包含重症监护室患者出院摘要,每个摘要都用多个标签进行注释。ECtHR 包含法律案件欧洲人权法院,每一个都被映射到被指控为暴力的人权公约的文章。这两个数据集都包含非常长的文本文档; MIMIC 中的平均标记数为 2,395 个标记,最长的文档包含 14,562 个标记,而 ECtHR 中的平均和最长数字分别为 2,1971 和 49,392。我们从增加预训练的 RoBERTa 模型的序列长度来评估升力(我们重复位置嵌入,如 Beltagy 等人所示)。
Path-X and Path-256
Path-X 和 Path-256 基准是来自旨在测试长上下文的远程竞技场基准的具有挑战性的任务。该任务是对黑白 128 128(或 256 256 256)图像中的两点是否具有连接它们的路径进行分类,并一次将图像馈送到变压器一个像素。在之前的工作中,所有变压器模型要么耗尽内存,要么只实现了随机性能。已经搜索可以对此类长上下文进行建模的替代架构。我们在这里展示了 Transformer 模型能够解决 Path-X 和 Path-256 的第一个结果在表 6 中。我们在 Path-64 上预训练转换器,然后通过空间插值位置嵌入转移到 Path-X。FlashAttention 在 Path-X 上实现了 61.4 的准确度。此外,块稀疏 FlashAttention 使 Transformer 可以扩展到序列长度 64K,在 Path-256 上达到 63.1 的准确度 4。
4.3 Benchmarking Attention
[基准注意]
我们改变序列长度并测量 FlashAttention 和块稀疏 FlashAttention 在具有 40 GB HBM 的 A100 GPU 上的各种注意力基线的运行时和内存使用情况,带有 dropout 和填充掩码。我们比较了精确注意、近似注意和稀疏注意的参考实现。我们报告了正文中基线的子集;附录 E 包含更多基线和完整细节。
Runtime
图 3 (左) 报告了与精确、近似和稀疏注意力(附录 E 中的精确数字)的基线相比,FlashAttention 和块稀疏 FlashAttention 的前向 + 后向传递的毫秒运行时间(附录 E 中的精确数字)。运行时随序列长度呈二次增长,但 FlashAttention 运行速度明显快于确切的注意力基线,比 PyTorch 实现快 3 倍。许多近似/稀疏注意力机制的运行时间随序列长度线性增长,但由于内存访问较少,FlashAttention 仍然比短序列的近似和稀疏注意力运行得更快。近似注意力运行时开始在 512 到 1024 的序列上使用 FlashAttention 交叉。另一方面,块稀疏 FlashAttention 在所有序列长度中都比我们知道的确切、稀疏和近似注意力的所有实现更快。
Memory Footprint
图 3 (右) 显示了与各种精确、近似和稀疏注意力基线相比,FlashAttention 和块稀疏 FlashAttention 的内存占用。FlashAttention 和块稀疏 FlashAttention 具有相同的内存占用,随着序列长度线性增长。FlashAttention 比精确的注意力基线具有更高的内存效率,并且比近似注意力基线具有更高的内存效率。除了 Linformer 在 64K 之前在 A100 GPU 上运行内存之外的所有其他算法,FlashAttention 仍然比 Linformer 更有效 2x。
5.Limitations and Future Directions
[局限性和未来方向]
我们讨论了我们的方法和未来方向的局限性。相关工作在附录 A 中给出。
Compiling to CUDA
编译到 CUDA。我们目前构建注意力 IO 感知实现的方法需要为每个新的注意力实现编写一个新的 CUDA 内核。这需要在比 PyTorch 低得多的语言中编写注意力算法,并且需要显着的工程工作。实现也可能不能跨 GPU 架构转移。这些限制表明需要一种支持在高级语言(例如 PyTorch)中编写注意力算法并编译到 CUDA 中的 IO 感知实现的方法——类似于图像处理中的 Halide 等努力
IO-Aware Deep Learning
IO-Aware 深度学习。我们相信 IO 感知方法可以扩展到注意力之外。Attention 是 Transformers 中内存密集型最多的计算,但深度网络中的每一层都涉及 GPU HBM。我们希望我们的工作能够激发其他模块的 IO 感知实现。我们在附录 D 中讨论了这些潜在的扩展。
Multi-GPU IO-Aware Methods
多 GPU IO 感知方法。我们的 IO 感知注意力实现在用于计算单个 GPU 上的注意力的常数内是最优的。然而,注意力计算可以跨多个 GPU 并行化。使用多个 GPU 为 IO 分析添加了一个额外的层 - 考虑了 GPU 之间的数据传输。我们希望我们的工作能够启发未来的方向未来的工作。
Tips:
Flash Attention 核心思想是通过利用 SRAM 的高效计算,同时减少对 HBM 的访问次数,论文的附录中给出了详细的推理证明以及相关算法的实现细节,有需要的同学可以查阅原文学习:
https://arxiv.org/abs/2205.14135
后面我们也会基于上面的 Softmax 计算方法做 torch 代码的简单示例,欢迎大家持续关注~