#set document(title: "Transformer Architecture Optimizations")
#import "@preview/algo:0.3.0"
#import "@preview/diagraph:0.1.0"
#import "@preview/showybox:1.0.1"
#set page(
  numbering: "1",
  number-align: center,
  header: align(right)[Transformer Architecture Optimizations],
)
= Transformer Architecture Optimizations

== Introduction
Since their introduction in the seminal "Attention Is All You Need" paper by
Vaswani et al. in 2017, Transformer architectures have revolutionized natural
language processing and subsequently spread to dominate numerous other domains
including computer vision, audio processing, and multimodal learning. However, the
standard self-attention mechanism at the core of Transformers exhibits quadratic
complexity with respect to sequence length, creating significant computational
bottlenecks. This essay explores advanced optimization techniques for Transformer
architectures, with a particular focus on Flash Attention and other memory-efficient
attention implementations. Standard scaled dot-product attention is computed as
$
"Attention"(Q, K, V) = "softmax"((Q K^T) / sqrt(d)) V
$
where $Q, K, V in RR^(n times d)$ represent the query, key, and value projections.
The computational complexity of this operation is $O(n^2 d)$, with the quadratic term
becoming prohibitive for long sequences. Additionally, the standard implementation
requires storing the attention matrix $A = "softmax"((Q K^T) / sqrt(d)) in RR^(n times n)$
in memory, which becomes infeasible for large $n$.
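For reference, the following is a minimal PyTorch sketch of this standard implementation
(the function name and tensor shapes are illustrative, not taken from any library); it
makes the bottleneck explicit, since the full score matrix is materialized before the
softmax.

```python
import math
import torch

def naive_attention(q, k, v):
    # q, k, v: (batch, heads, seqlen, head_dim)
    d = q.shape[-1]
    # The (seqlen x seqlen) score matrix is materialized here, so memory
    # grows quadratically with sequence length.
    scores = q @ k.transpose(-2, -1) / math.sqrt(d)
    attn = torch.softmax(scores, dim=-1)
    return attn @ v
```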
== Flash Attention

Flash Attention (Dao et al., 2022) restructures the attention computation so that the
full $n times n$ attention matrix is never materialized in GPU high-bandwidth memory,
relying on three ideas:

#showybox(
title: "Key Optimizations in Flash Attention",
frame: (
border-color: blue,
title-color: blue.darken(30%),
)
)[
1. *Block-wise computation*: Processing the attention matrix in tiles to fit in
fast GPU memory (SRAM)
2. *Recomputation during backpropagation*: Avoiding storage of intermediate
attention matrices
3. *Kernel fusion*: Combining multiple operations into single GPU kernels
]
#algo(
  title: "Flash Attention Forward Pass (simplified)",
  parameters: ("Q, K, V matrices", "block size B"),
  line-numbers: true,
)[
  partition $Q$, $K$, $V$ into row blocks of size $B$\
  initialize output $O = 0$, row sums $l = 0$, row maxima $m = -infinity$ in HBM\
  for each key/value block $(K_j, V_j)$, loaded into SRAM: #i\
    for each query block $Q_i$, loaded into SRAM: #i\
      $S <- (Q_i K_j^T) / sqrt(d)$\
      update $m_i$ and $l_i$ with the online softmax and rescale $O_i$\
      $O_i <- O_i + exp(S - m_i) V_j$ #d #d\
  return $O$ with each row divided by its $l$
]
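To make the block-wise recurrence concrete, here is an illustrative PyTorch sketch of
the tiling and online softmax. It runs on plain tensors, tiles only the key/value
dimension for brevity, and is not the fused CUDA kernel; all names are ours.

```python
import math
import torch

def tiled_attention(q, k, v, block_size=128):
    """Block-wise attention with an online softmax (illustrative only).

    q, k, v: (seqlen, head_dim) for a single head. Only the key/value
    dimension is tiled here; the real kernel also tiles the queries and
    fuses everything into a single GPU kernel.
    """
    n, d = q.shape
    out = torch.zeros_like(q)                                    # running, unnormalized output O
    row_max = torch.full((n, 1), float("-inf"),
                         dtype=q.dtype, device=q.device)         # running row maxima m
    row_sum = torch.zeros(n, 1, dtype=q.dtype, device=q.device)  # running denominators l

    for start in range(0, n, block_size):
        k_blk = k[start:start + block_size]    # one K/V tile at a time
        v_blk = v[start:start + block_size]
        scores = q @ k_blk.T / math.sqrt(d)    # (n, B) tile of Q K^T / sqrt(d)

        new_max = torch.maximum(row_max, scores.max(dim=-1, keepdim=True).values)
        correction = torch.exp(row_max - new_max)   # rescale previous partial results
        p = torch.exp(scores - new_max)

        row_sum = row_sum * correction + p.sum(dim=-1, keepdim=True)
        out = out * correction + p @ v_blk
        row_max = new_max

    return out / row_sum
```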
Empirical evaluations report that Flash Attention speeds up the attention computation
by roughly 2-4× while using substantially less memory, enabling longer context windows.
== Sparse and Linear Attention

Several other techniques for optimizing Transformer attention have been developed.
Sparse attention patterns can reduce the computational complexity to $O(n sqrt(n))$
or even $O(n log n)$. Prominent approaches include sliding-window (local) attention
as in Longformer, combinations of local, global, and random blocks as in BigBird, and
strided or fixed patterns as in the Sparse Transformer.
These approaches can be formalized as structured sparsity masks applied to the full
attention matrix:
$
"Attention"(Q, K, V) = "softmax"((Q K^T) / sqrt(d) + M) V
$
where $M in {0, -infinity}^(n times n)$ represents the sparsity mask, added to the scores before the softmax.
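As a hedged illustration of such masking, the sketch below builds a sliding-window
(banded) mask in PyTorch; the function name and `window` parameter are ours. Note that
masking a dense score matrix does not by itself reduce cost: efficient implementations
exploit the block structure of $M$ so that masked tiles are never computed at all.

```python
import math
import torch

def local_attention(q, k, v, window=128):
    """Sliding-window attention via a {0, -inf}-style additive mask (illustrative)."""
    n, d = q.shape
    scores = q @ k.T / math.sqrt(d)
    # Disallow attention between positions more than `window` tokens apart.
    idx = torch.arange(n, device=q.device)
    banned = (idx[None, :] - idx[:, None]).abs() > window
    scores = scores.masked_fill(banned, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v
```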
Linear attention methods take a different route: they replace the softmax with a
kernel feature map $phi(dot)$, so that $phi(K)^T V$ can be computed once and the
overall cost drops from $O(n^2 d)$ to $O(n d^2)$ (the row-wise normalizer is omitted
here for brevity):

$
"LinearAttention"(Q, K, V) = phi(Q)(phi(K)^T V)
$
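A minimal sketch of this idea, assuming the $phi(x) = "elu"(x) + 1$ feature map
popularized by Katharopoulos et al. (2020); the function is non-causal and purely
illustrative.

```python
import torch
import torch.nn.functional as F

def linear_attention(q, k, v, eps=1e-6):
    """Non-causal linear attention with the elu(x) + 1 feature map (illustrative).

    q, k, v: (seqlen, head_dim). Cost is O(n * d^2) rather than O(n^2 * d).
    """
    phi_q = F.elu(q) + 1                   # positive feature maps
    phi_k = F.elu(k) + 1
    kv = phi_k.T @ v                       # (d, d) summary of keys and values
    normalizer = phi_q @ phi_k.sum(dim=0)  # row-wise denominator, shape (n,)
    return (phi_q @ kv) / (normalizer.unsqueeze(-1) + eps)
```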
== Practical Implementations
The flash-attn library exposes fused attention kernels directly. A thin wrapper over
its packed-QKV entry point looks like this; the packed tensor is consumed as-is, so no
reshaping is needed.

```python
from flash_attn import flash_attn_qkvpacked_func

def flash_attention_forward(qkv):
    """Run Flash Attention on packed QKV projections.

    qkv: (batch_size, seqlen, 3, num_heads, head_dim),
         fp16 or bf16, on a CUDA device.
    Returns a tensor of shape (batch_size, seqlen, num_heads, head_dim).
    """
    # The kernel consumes the packed 5-D layout directly.
    output = flash_attn_qkvpacked_func(qkv, dropout_p=0.0, causal=False)
    return output
```
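A usage sketch, assuming a CUDA device and half-precision inputs (which the fused
kernels require); the shapes are arbitrary examples.

```python
import torch

# Assumes flash_attention_forward from above; shapes are arbitrary examples.
batch, seqlen, num_heads, head_dim = 2, 4096, 16, 64
qkv = torch.randn(batch, seqlen, 3, num_heads, head_dim,
                  dtype=torch.float16, device="cuda")
out = flash_attention_forward(qkv)  # (batch, seqlen, num_heads, head_dim)
```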
== Conclusion
Memory-aware algorithms such as Flash Attention, together with sparse and linear
attention variants, have substantially eased the quadratic bottleneck of standard
self-attention. The interplay between these algorithmic innovations and hardware
acceleration will remain a key factor in the continued scaling of Transformer-based
AI systems, driving further advances in model capabilities and applications.