前言
自2017年那篇石破天惊的论文《Attention Is All You Need》问世以来,Transformer架构便奠定了现代人工智能的基石。它彻底摒弃了循环和卷积,仅凭注意力机制就实现了并行化处理,极大地释放了模型的潜力。我们也在《从零训练大模型之模型搭建》这篇文章中,按照《Attention Is All You Need》的内容进行了模型实现,也训练出了一个pre-train模型。
然而,如果我们把2017年的原始Transformer比作一辆福特T型车,那么今天我们所熟知的Llama、Mistral等大语言模型(LLM)则更像是F1赛车。它们的外形或许依稀可见当年的影子,但内部的每一个核心部件都经历了翻天覆地的技术革命。
这场深刻的进化主要由两大驱动力催生:对极致效率的追求和对卓越性能的渴望。一方面,随着模型参数从数亿暴涨至数万亿,如何降低训练和推理的算力成本、减少内存占用,成为了一个生死攸关的工程问题。另一方面,用户对模型能力的要求也与日俱增,我们需要模型能够理解更长的上下文、生成更高质量的内容,并在各类任务上表现得更“聪明”。
这篇文章我们将深入探讨这几年来的技术革命。我们将逐一剖析那些对现代大模型产生深远影响的关键技术,不仅解释它们“是什么”,更会深入探讨“为什么”它们有效,并为此提供基于PyTorch的“如何实现”的代码。读完本文,您将清晰地了解以下核心技术如何重塑了今天的LLM:
-
FlashAttention:如何通过I/O感知优化,攻克注意力机制的内存墙。
-
Grouped-Query Attention (GQA):如何在保证质量的同时,为模型推理按下“加速键”。
-
Rotary Positional Embedding (RoPE):如何用旋转的艺术,让模型更优雅地理解位置关系。
-
RMSNorm & SwiGLU:看似微小的基础模块升级,如何带来显著的性能与效率提升。
-
Mixture of Experts (MoE):如何利用“稀疏激活”的哲学,构建万亿参数的巨无霸模型。
让我们一起踏上这段从经典Transformer到前沿大模型的探索之旅。
1. 注意力机制的效率革命
注意力机制是Transformer的心脏,但它 O ( N 2 ) O(N^2) O(N2)的复杂度也曾是其最大的缺点之一。有趣的是,研究者们发现,真正的瓶颈并不仅仅在于计算量,而在于更隐蔽的内存读写。针对训练和推理这两个不同场景,诞生了两种截然不同的优化思路。
1.1 FlashAttention:攻克内存墙
长久以来,人们普遍认为注意力机制的二次方复杂度是一个纯粹的计算(FLOPs)问题。然而,FlashAttention的提出者们敏锐地洞察到,对于现代GPU而言,真正的瓶颈在于内存I/O。
我们可以将GPU的内存系统想象成一个两层结构:一层是容量巨大但速度较慢的高带宽内存(HBM),好比一个大仓库;另一层是容量很小但速度快如闪电的片上SRAM,如同车间里的工作台。标准的PyTorch注意力实现非常“朴素”:它需要在HBM这个“大仓库”中完整地生成并存储一个巨大的N×N注意力分数矩阵(其中N是序列长度),然后再把它分块搬到SRAM这个“工作台”上进行计算。当N变得很大时(例如8192),这个中间矩阵会占用惊人的内存(例如,FP16精度下为8192×8192×2 bytes=128MB),频繁地在HBM和SRAM之间搬运数据,导致GPU的计算核心(“工人”)大部分时间都在“等米下锅”,这就是所谓的I/O瓶颈 。
FlashAttention的解决方案堪称绝妙,其核心思想是:彻底避免在HBM中写入和读出完整的注意力矩阵 。它通过两种经典的高性能计算技术实现了这一点:
-
Tiling(分块/瓦片化):FlashAttention将Query(Q)、Key(K)、Value(V)矩阵分解成更小的“块”(Tiles)。算法在SRAM中加载一小块Q,然后迭代地从HBM中加载一小块K和V。在SRAM内部,它计算这一小块Q与K的点积、应用Softmax、再乘以V,得到一个局部的注意力输出。这个过程会为下一块K、V的计算动态地更新一个归一化统计量。通过一种被称为“在线Softmax”(Online Softmax)的巧妙算法,它可以在不访问全局注意力分数的情况下,精确地计算出最终结果。
-
Kernel Fusion(核函数融合):上述整个分块计算过程,包括矩阵乘法、Softmax和掩码操作,被整合进一个单一的、高度优化的GPU计算核心(Kernel)中。这意味着所有中间结果都保留在极速的SRAM中,大大减少了与HBM之间的数据交换次数,从而让GPU的计算单元能够火力全开。
这项技术的意义是革命性的。它将注意力机制的内存占用从 O ( N 2 ) O(N^2) O(N2)降低到了 O ( N ) O(N) O(N),并带来了2-4倍的训练速度提升。此后,FlashAttention还在不断进化,FlashAttention-2和FlashAttention-3等版本针对NVIDIA Hopper等新一代GPU的硬件特性(如异步计算单元和FP8低精度格式)进行了更深度的优化,将硬件利用率推向了新的高峰。
幸运的是,我们普通开发者无需从头实现如此复杂的底层代码。从PyTorch 2.0开始,这一强大的技术已经通过torch.nn.functional.scaled_dot_product_attention(SDPA)
函数被无缝集成。当输入满足特定条件时(如在CUDA上,输入为FP16/BF16等),PyTorch会自动调用FlashAttention后端,让我们轻松享受到性能红利。
PyTorch实现
现代实用方法
在实际项目中,我们推荐直接使用PyTorch内置的SDPA函数。这是最简单、最高效的方式。
import torch
import torch.nn.functional as F
# 假设我们有Q, K, V,形状为 (batch_size, num_heads, seq_len, head_dim)
# 并且在CUDA上,数据类型为float16或bfloat16
q = torch.randn(2, 4, 1024, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 4, 1024, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 4, 1024, 64, device="cuda", dtype=torch.float16)
# PyTorch会自动选择最优的后端,在满足条件时就是FlashAttention
# is_causal=True 用于解码器中的因果掩码
output = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print("Output shape:", output.shape)
概念性简化实现
为了帮助初学者理解FlashAttention的算法思想,下面提供一个纯PyTorch的简化实现。请注意:此代码仅为教学目的,它不包含CUDA优化,运行速度会很慢,但它清晰地展示了分块计算和在线Softmax的核心逻辑。纯算法的实现,如果看起来吃力,可以先跳过。
import torch
import math
def simplified_flash_attention(q, k, v, block_size=128):
"""
一个纯PyTorch的FlashAttention前向传播简化实现,用于教学。
注意:这个实现非常慢,仅用于理解算法思想。
"""
# 假设我们有Q, K, V,形状为 (batch_size, num_heads, seq_len, head_dim)
batch_size, num_heads, seq_len, head_dim = q.shape
scale = 1.0 / math.sqrt(head_dim)
output = torch.zeros_like(q)
# 初始化在线Softmax所需的统计量
# l_i 存储每个查询位置的softmax分母的当前估计值
# m_i 存储每个查询位置的行最大值的当前估计值
l = torch.zeros(batch_size, num_heads, seq_len, 1, device=q.device)
m = torch.full((batch_size, num_heads, seq_len, 1), -float('inf'), device=q.device)
# 将Q按行分块
for i in range(0, seq_len, block_size):
q_block = q[:, :, i:i+block_size, :]
# 对于每个Q块,我们需要迭代所有K,V块来计算精确的输出
# 初始化当前块的统计量
l_block = torch.zeros(batch_size, num_heads, q_block.shape, 1, device=q.device)
m_block = torch.full((batch_size, num_heads, q_block.shape, 1), -float('inf'), device=q.device)
output_block = torch.zeros_like(q_block)
# 将K,V按列分块
for j in range(0, seq_len, block_size):
k_block = k[:, :, j:j+block_size, :]
v_block = v[:, :, j:j+block_size, :]
# --- 核心计算步骤(在SRAM中进行)---
# 计算S_ij = Q_i * K_j^T
s_ij = torch.matmul(q_block, k_block.transpose(-1, -2)) * scale
# --- 在线Softmax更新 ---
# 找到新的行最大值
m_ij = torch.max(s_ij, dim=-1, keepdim=True)
m_new = torch.max(m_block, m_ij)
# 计算新的softmax分母
p_ij = torch.exp(s_ij - m_new)
l_new = torch.exp(m_block - m_new) * l_block + torch.sum(p_ij, dim=-1, keepdim=True)
# --- 更新输出 ---
# 对之前的输出进行缩放
output_block = output_block * (l_block / l_new) * torch.exp(m_block - m_new)
# 加上当前块的贡献
output_block += torch.matmul(p_ij, v_block)
# 更新统计量
m_block = m_new
l_block = l_new
# 将计算好的块输出写回HBM
output[:, :, i:i+block_size, :] = output_block
return output
# 测试
q_test = torch.randn(1, 1, 256, 64, device="cuda", dtype=torch.float32)
k_test = torch.randn(1, 1, 256, 64, device="cuda", dtype=torch.float32)
v_test = torch.randn(1, 1, 256, 64, device="cuda", dtype=torch.float32)
# 简化版FlashAttention
output_flash_simplified = simplified_flash_attention(q_test, k_test, v_test)
# 标准Attention作为对比
attn_scores = torch.matmul(q_test, k_test.transpose(-1, -2)) / math.sqrt(64)
attn_probs = F.softmax(attn_scores, dim=-1)
output_standard = torch.matmul(attn_probs, v_test)
# 验证结果是否接近
print("Outputs are close:", torch.allclose(output_flash_simplified, output_standard, atol=1e-5))
FlashAttention的成功标志着深度学习优化思路的一次重要转变:从单纯关注计算复杂度(FLOPs)转向了对I/O感知的算法设计。它揭示了一个深刻的道理:在现代硬件上,算法的性能不仅取决于它需要执行多少次乘法和加法,更取决于它如何与硬件的内存层级结构和谐共处。算法不再是脱离硬件的纯粹数学,而是与硬件深度耦合的工程艺术。这一理念为后续的许多优化工作铺平了道路。
1.2 Grouped-Query Attention (GQA):加速推理
FlashAttention主要解决了训练时的吞吐量问题,但在模型推理,尤其是自回归生成(即逐个token生成文本)的场景下,存在一个不同的瓶颈:KV缓存(KV Cache)。
在生成每个新token时,模型需要回顾(attend to)所有已经生成的token。为了避免重复计算,我们会将之前每个token的Key(K)和Value(V)向量缓存起来。随着生成序列的增长,这个KV缓存会越来越大。在每一步生成中,GPU都需要从HBM中读取完整的KV缓存,这成为了推理速度的主要限制因素。KV缓存的大小直接与K和V头的数量成正比。
为了解决这个问题,研究者们在原始的多头注意力(MHA)基础上,提出了一系列变体,形成了一个从追求极致质量到追求极致速度的谱系。
-
Multi-Head Attention (MHA):这是Transformer的原始设计。它有H个查询头(Query Head),同时也就有H个独立的键头(Key Head)和值头(Value Head)。这种设计提供了最强的模型表达能力,但其KV缓存也是最大的,导致推理速度较慢。
-
Multi-Query Attention (MQA):这是一种激进的优化策略。它仍然有H个查询头,但所有的查询头共享唯一的一对K/V头。这使得KV缓存的大小减小了H倍,极大地降低了内存带宽需求,从而显著提升了推理速度。然而,这种激进的共享也可能导致模型质量的明显下降和训练不稳定 。
-
Grouped-Query Attention (GQA):GQA是MHA和MQA之间的一个巧妙折中。它将H个查询头分成G个组(1 < G < H),每个组内的查询头共享一对K/V头。这样,总共就有G对K/V头。GQA提供了一个灵活的旋钮:当G=H时,它等价于MHA;当G=1时,它等价于MQA。通过选择合适的G值,GQA可以在接近MQA的速度下,达到几乎与MHA相媲美的模型质量 。如今,GQA已成为Llama、Mistral等主流开源模型的标准配置。
特性 | Multi-Head Attention (MHA) | Grouped-Query Attention (GQA) | Multi-Query Attention (MQA) |
---|---|---|---|
查询头数量 | H | H | H |
键/值头数量 | H | G (其中 1 < G < H) | 1 |
KV缓存大小 | 大 (与H成正比) | 中等 (与G成正比) | 小 (常数) |
推理速度 | 慢 | 快 | 最快 |
模型质量 | 高 | 接近高 | 可能下降 |
核心思想 | 每个查询头都有专属的K/V对 | 每组查询头共享一对K/V对 | 所有查询头共享唯一的K/V对 |
PyTorch实现
实现GQA的关键在于正确处理K和V的投影和形状。查询(Q)被投影到num_heads * head_dim的维度,而键(K)和值(V)则被投影到num_kv_heads * head_dim的维度。在计算注意力之前,需要将K和V的头重复num_groups次,以匹配Q的头的数量。
实现例子如下,也是算法层面的内容,加了很多注释,这块可以看看。
class GroupedQueryAttention(nn.Module):
def __init__(self, embed_dim, num_heads, num_kv_heads, head_dim):
super().__init__()
self.num_heads = num_heads # Query头数量: 4
self.num_kv_heads = num_kv_heads # K/V头数量: 2
self.num_groups = num_heads // num_kv_heads # 每组Q头共享一个KV头的Q头数量: 4 // 2 = 2
self.head_dim = head_dim # 每个头的维度: 3
# Q, K, V的线性投影层
# 注意输出维度:
# q_proj 输出 (num_heads * head_dim) = (4 * 3 = 12)
# k_proj 输出 (num_kv_heads * head_dim) = (2 * 3 = 6)
# v_proj 输出 (num_kv_heads * head_dim) = (2 * 3 = 6)
self.q_proj = nn.Linear(embed_dim, num_heads * head_dim, bias=False)
self.k_proj = nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False)
self.v_proj = nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False)
self.o_proj = nn.Linear(num_heads * head_dim, embed_dim, bias=False) # 输出投影层
def forward(self, x, is_causal=False):
batch_size, seq_len, _ = x.shape
# 初始输入 x 形状: (bs, sl, embed_dim) -> (1, 2, 12)
# 1. 投影 Q, K, V
# Q: 将输入 x (1, 2, 12) 投影到 (1, 2, 12)
q = self.q_proj(x)
# K: 将输入 x (1, 2, 12) 投影到 (1, 2, 6)
k = self.k_proj(x)
# V: 将输入 x (1, 2, 12) 投影到 (1, 2, 6)
v = self.v_proj(x)
# 2. 调整形状以分离头
# 这是为了将扁平化的投影结果重新组织成 (头数量, head_dim) 的结构
# q: (bs, seq_len, num_heads, head_dim) -> (1, 2, 4, 3)
q = q.view(batch_size, seq_len, self.num_heads, self.head_dim)
# k: (bs, seq_len, num_kv_heads, head_dim) -> (1, 2, 2, 3)
k = k.view(batch_size, seq_len, self.num_kv_heads, self.head_dim)
# v: (bs, seq_len, num_kv_heads, head_dim) -> (1, 2, 2, 3)
v = v.view(batch_size, seq_len, self.num_kv_heads, self.head_dim)
# 3. GQA的核心:扩展K和V的头以匹配Q的头
# 这一步是GQA与MHA/MQA的关键区别。
# K和V现在只有 num_kv_heads = 2 个头,但Q有 num_heads = 4 个头。
# 我们需要让K和V的头数量也变成4,以便与Q进行点积运算。
# 但我们不是通过额外的投影层来增加头,而是通过重复已有的K/V头。
# (bs, seq_len, num_kv_heads, head_dim) -> (1, 2, 2, 3)
# k.unsqueeze(3): 在第3个维度 (num_kv_heads 后面) 增加一个维度,用于存放组信息。
# 新形状: (bs, seq_len, num_kv_heads, 1, head_dim) -> (1, 2, 2, 1, 3)
k = k.unsqueeze(3)
# k.expand(...): 将新增加的维度 (原来是1) 扩展为 num_groups (2)。
# -1 表示保持当前维度大小不变。
# 扩展后的形状: (bs, seq_len, num_kv_heads, num_groups, head_dim) -> (1, 2, 2, 2, 3)
k = k.expand(-1, -1, -1, self.num_groups, -1)
# 同理,v 也进行相同的操作: (1, 2, 2, 2, 3)
v = v.unsqueeze(3).expand(-1, -1, -1, self.num_groups, -1)
# 理解 expand 的效果:
# 假设原始 K 有 k_0, k_1 两个头。
# unsqueeze 后,k_0 -> k_0_single_group, k_1 -> k_1_single_group
# expand 后:
# k_0_group0, k_0_group1
# k_1_group0, k_1_group1
# 实际上,k_0_group0 和 k_0_group1 共享同一块内存,因为 expand 是视图操作,不复制数据。
# 合并kv头和组的维度,得到与q匹配的头数量
# 现在K和V的形状是 (bs, sl, num_kv_heads, num_groups, head_dim)
# 我们需要将其展平为 (bs, sl, num_heads, head_dim),其中 num_heads = num_kv_heads * num_groups
# k: (1, 2, 2, 2, 3) -> (1, 2, 4, 3)
k = k.reshape(batch_size, seq_len, self.num_heads, self.head_dim)
# v: (1, 2, 2, 2, 3) -> (1, 2, 4, 3)
v = v.reshape(batch_size, seq_len, self.num_heads, self.head_dim)
# 现在,Q, K, V 都具有 (1, 2, 4, 3) 的形状,可以进行多头注意力计算了。
# 重要的是,虽然K和V现在也有4个头,但实际上它们是基于原始2个K/V头复制而来的,
# 内存占用只与 num_kv_heads 相关,而不是 num_heads。
# 4. 为了使用PyTorch的SDPA (Scaled Dot Product Attention),需要将头维度移到前面
# PyTorch的SDPA期望输入形状是 (bs, num_heads, seq_len, head_dim)
# q: (1, 2, 4, 3) -> (1, 4, 2, 3)
q = q.transpose(1, 2)
# k: (1, 2, 4, 3) -> (1, 4, 2, 3)
k = k.transpose(1, 2)
# v: (1, 2, 4, 3) -> (1, 4, 2, 3)
v = v.transpose(1, 2)
# 5. 计算注意力
# F.scaled_dot_product_attention 自动处理缩放因子和softmax。
# 如果 is_causal=True,则应用因果掩码(对解码器有用)。
# output 形状: (bs, num_heads, seq_len, head_dim) -> (1, 4, 2, 3)
output = F.scaled_dot_product_attention(q, k, v, is_causal=is_causal)
# 6. 合并头并进行最终投影
# 将头维度移回 (bs, seq_len, num_heads, head_dim) -> (1, 2, 4, 3)
output = output.transpose(1, 2)
# .contiguous() 是为了确保张量在内存中是连续的,以便 .view 操作能正确执行。
# .view(bs, sl, -1) 将 num_heads 和 head_dim 合并回 embed_dim
# (1, 2, 4, 3) -> (1, 2, 12)
output = output.contiguous().view(batch_size, seq_len, -1)
# 最终投影回原始的 embed_dim
# (1, 2, 12) -> (1, 2, 12)
return self.o_proj(output)
# 示例运行:
embed_dim = 12
num_heads = 4
num_kv_heads = 2 # 这是GQA的关键参数,它小于num_heads
head_dim = embed_dim // num_heads # = 3
gqa_layer = GroupedQueryAttention(embed_dim, num_heads, num_kv_heads, head_dim)
input_tensor = torch.randn(1, 2, embed_dim) # (bs, sl, embed_dim) -> (1, 2, 12)
output_tensor = gqa_layer(input_tensor)
print("Output shape:", output_tensor.shape) # 期望输出: torch.Size([1, 2, 12])
GQA的广泛应用揭示了另一个深刻的趋势:训练和推理的最佳架构正在发生分化。一个在并行化训练中能最大化模型表达能力的架构(如MHA),并不一定是序列化、内存受限的推理场景下的最高效选择。一些研究中提出的“uptraining”方法,即用少量算力将预训练好的MHA模型转换为GQA模型,正是这一趋势的有力证明 。这表明,现代LLM的设计已经演变成一个两阶段问题:第一阶段,在预训练中最大化学习能力;第二阶段,为高效部署进行优化。像GQA这样的技术,正是第二阶段问题的完美答案
2. 位置编码的革新:旋转的艺术(RoPE)
Transformer的原始设计中没有内置处理序列顺序的能力,因此需要引入位置编码(Positional Encoding)。最初的方法是将一个代表绝对位置的正弦/余弦信号向量加到每个词的嵌入向量上。这种方法虽然有效,但模型需要间接地从这些绝对位置信号中学习到相对位置关系。
现代LLM几乎一致地转向了一种更优雅、更强大的方案:旋转位置编码(Rotary Positional Embedding, RoPE)。
从加法到旋转
RoPE的核心思想是,不再将位置信息作为一种静态的“属性”加到词向量上,而是将其视为一种动态的“视角”来旋转查询(Q)和键(K)向量。
我们可以用一个几何直觉来理解它:
-
将每个Q/K向量的特征维度两两配对,想象成一系列二维平面上的点。例如,一个128维的向量可以看作64个二维坐标点。
-
对于一个在序列中绝对位置为m的token,RoPE会用一个特定的旋转矩阵 R m R_m Rm来旋转它对应的Q和K向量中的每一个二维坐标点。这个旋转矩阵的角度与位置m成正比,即旋转角度为 m θ i m\theta_i mθi,其中 θ i \theta_i θi是为第i个二维平面预设的基准角频率。
-
神奇之处在于点积运算。当计算位置为m的查询 q m q_m qm和位置为n的键 k n k_n kn之间的注意力分数时,它们的点积 ⟨ R m q m , R n k n ⟩ \langle R_m q_m, R_n k_n \rangle ⟨Rmqm,Rnkn⟩在数学上等价于一个仅依赖于相对位置(m-n)的变换作用于原始点积 ⟨ q m , k n ⟩ \langle q_m, k_n \rangle ⟨qm,kn⟩上。绝对位置m和n的信息在点积的几何关系中被“抵消”了,只留下了它们之间的相对关系。
这种将相对位置信息内嵌到注意力计算中的设计,是RoPE成功的关键。它为模型提供了一个极强的归纳偏置(inductive bias),使其能够天然地理解相对距离,而无需从绝对位置中费力学习。这解释了为什么RoPE在长度外推(length extrapolation)任务上表现出色:即使模型在训练时只见过长度为2048的序列,它也能很好地泛化到更长的序列上,因为相对位置的编码方式是不变的。RoPE已经成为Llama、Mistral等几乎所有现代主流LLM的基石。
这种将位置视为一种动态视角而非静态属性的理念,是思想上的一次深刻飞跃。它不仅是一种更好的编码技术,更是一种与注意力机制内在原理高度契合的信息呈现方式。这也解释了为什么RoPE的思想正在被探索应用于视觉和视频等其他领域,因为在这些领域中,空间和时间的相对关系同样至关重要 。
PyTorch实现
下面的代码从零开始实现了一个RoPE模块。它展示了如何预计算旋转频率、如何根据输入位置生成cos和sin缓存,以及如何将旋转应用于Q和K向量。
纯算法的实现,如果看起来吃力,可以继续先跳过。
import torch
import torch.nn as nn
class RotaryPositionalEmbeddings(nn.Module):
def __init__(self, dim, base=10000):
super().__init__()
# 预计算旋转频率 a.k.a. theta
# 论文中的公式是 theta_i = 10000^(-2(i-1)/d)
# 这里 i 的范围是 [1, 2,..., d/2]
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer("inv_freq", inv_freq)
self.cos_cached = None
self.sin_cached = None
def _update_cache(self, x, seq_len):
# 检查缓存是否需要更新
if self.cos_cached is not None and self.cos_cached.shape >= seq_len:
return
print(f"Updating RoPE cache for seq_len: {seq_len}")
# 计算位置 t = [0, 1,..., seq_len-1]
t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
# 计算 t * theta
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# 得到 [t*theta_0, t*theta_1,..., t*theta_{d/2-1}]
# 将 freqs 扩展成 [t*theta_0, t*theta_0, t*theta_1, t*theta_1,...]
# 以便同时应用于 (x_0, x_1), (x_2, x_3),... 对
emb = torch.cat((freqs, freqs), dim=-1)
# 缓存cos和sin值
# 形状: (seq_len, dim) -> (seq_len, 1, 1, dim) 以便广播
self.cos_cached = emb.cos()[None, None, :, :]
self.sin_cached = emb.sin()[None, None, :, :]
def rotate_half(self, x):
# 将最后一个维度对半分开
x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
# 交换并取反其中一半
return torch.cat((-x2, x1), dim=-1)
def forward(self, q, k):
# q, k 的形状: (batch_size, num_heads, seq_len, head_dim)
seq_len = q.shape[-2]
self._update_cache(q, seq_len)
# 获取对应长度的缓存
cos = self.cos_cached[:, :, :seq_len, :]
sin = self.sin_cached[:, :, :seq_len, :]
# 应用旋转
# q_rot = q * cos + rotate_half(q) * sin
# k_rot = k * cos + rotate_half(k) * sin
q_rot = (q * cos) + (self.rotate_half(q) * sin)
k_rot = (k * cos) + (self.rotate_half(k) * sin)
return q_rot, k_rot
# 示例
rope = RotaryPositionalEmbeddings(dim=64)
q_vec = torch.randn(2, 4, 10, 64) # (bs, n_heads, seq_len, head_dim)
k_vec = torch.randn(2, 4, 10, 64)
q_rotated, k_rotated = rope(q_vec, k_vec)
print("Rotated Q shape:", q_rotated.shape)
print("Rotated K shape:", k_rotated.shape)
# 模拟推理时KV缓存的情况
q_infer = torch.randn(2, 4, 1, 64) # 新的单个查询
k_cache = k_rotated # 假设k_rotated是之前的KV缓存
# 在实际应用中,需要根据新token的位置来计算旋转
# Llama的实现中会传入一个位置id张量来处理这种情况
3. 基础模块升级
Transformer的革命不仅发生在高层的注意力结构,也深入到了构成每个Transformer块(Block)的最基础组件中。其中,归一化层(Normalization)和前馈网络(Feed-Forward Network, FFN)中的激活函数都经历了重要的迭代
3.1 RMSNorm:更精简、更快速的归一化
原始Transformer使用LayerNorm来稳定训练过程中的梯度和激活值 。LayerNorm的计算公式如下:
y = x − E [ x ] V a r [ x ] + ϵ ⋅ γ + β y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} \cdot \gamma + \beta y=Var[x]+ϵx−E[x]⋅γ+β
它通过减去均值(re-centering)和除以标准差(re-scaling)来将层输入标准化。长期以来,人们认为这两个操作都是LayerNorm成功的关键。
然而,《Root Mean Square Layer Normalization》这篇论文提出了一个大胆的假设:re-centering操作可能是多余的。基于这个假设,RMSNorm诞生了。它极大地简化了公式,只保留了re-scaling部分:
y = x E [ x 2 ] + ϵ ⋅ γ y = \frac{x}{\sqrt{\mathrm{E}[x^2] + \epsilon}} \cdot \gamma y=E[x2]+ϵx⋅γ
这里的 E [ x 2 ] \sqrt{\mathbb{E}[x^2]} E[x2]就是输入的均方根(Root Mean Square)。通过移除均值计算和减法操作,RMSNorm在计算上变得更简单,从而更高效。实验证明,这种简化在性能上几乎没有损失,却能带来显著的速度提升(原论文报告了7%-64%的加速)。这一“少即是多”的哲学使其迅速成为Llama、Mistral等现代LLM的首选归一化方法。
RMSNorm的成功是一个通过“经验本质主义”推动进步的绝佳案例。它挑战了既有认知,通过实验证明了某个被认为是“核心”的组件其实可以被舍弃。有趣的是,多年后,有研究从几何角度为这一现象提供了理论解释,证明在训练好的LLM中,隐藏表示自然地分布在一个特定的子空间里,使得LayerNorm的均值减法操作在效果上变得冗余 。这个故事告诉我们,技术的进步不仅在于增加新的复杂性,更在于深刻理解并移除不必要的复杂性。
PyTorch实现
尽管PyTorch从2.1版本开始已经内置了torch.nn.RMSNorm
,但从零开始实现它能帮助我们更好地理解其原理。
import torch
import torch.nn as nn
class RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
# gamma是可学习的缩放参数
self.weight = nn.Parameter(torch.ones(dim))
def _norm(self, x):
# 计算均方根:sqrt(mean(x^2) + eps)
# keepdim=True 保证结果的维度可以进行广播
rms = (x.pow(2).mean(-1, keepdim=True) + self.eps).sqrt()
return x / rms
def forward(self, x):
# x的形状通常是 (batch_size, seq_len, dim)
# _norm(x) 的结果形状也是 (batch_size, seq_len, dim)
# self.weight 的形状是 (dim)
# 广播机制会自动将weight扩展到 (1, 1, dim) 与输出相乘
output = self._norm(x.float()).type_as(x)
return output * self.weight
# 示例
rmsnorm_layer = RMSNorm(dim=128)
input_tensor = torch.randn(2, 10, 128)
output_tensor = rmsnorm_layer(input_tensor)
print("Output shape:", output_tensor.shape)
# 使用PyTorch内置版本
pytorch_rmsnorm = nn.RMSNorm(128)
output_pytorch = pytorch_rmsnorm(input_tensor)
print("PyTorch RMSNorm output shape:", output_pytorch.shape)
3.2 SwiGLU:更智能的激活门控
原始Transformer的FFN层使用的是简单而经典的ReLU激活函数。然而,现代LLM几乎都转向了更复杂但性能更优的门控线性单元(Gated Linear Unit, GLU)变体,其中SwiGLU是应用最广泛的一种 。
与ReLU(xW + b)
这样单一的线性变换加非线性激活不同,GLU变体引入了门控机制。在Llama等模型的FFN层中,SwiGLU的实现通常涉及三个线性投影(我们称之为W1, W2, W3):
- 输入x首先经过两个并行的线性变换,得到 x W 1 xW_1 xW1和 x W 3 xW_3 xW3。
- 其中一个结果 x W 1 xW_1 xW1,被送入Swish激活函数(也称为SiLU),计算出 Swish ( x W 1 ) 。 S w i s h 函数的定义是Swish ( x ) = x ⋅ σ ( x ) \text{Swish}(xW_1)。Swish函数的定义是\text{Swish}(x) = x \cdot \sigma(x) Swish(xW1)。Swish函数的定义是Swish(x)=x⋅σ(x),其中 σ \sigma σ是Sigmoid函数。这个结果就像一个“门控信号”。
- 这个门控信号与另一个线性变换的结果 x W 3 xW_3 xW3进行逐元素相乘。
- 最后,将门控后的结果通过第三个线性变换 W 2 W_2 W2输出。
整个公式可以写为:
FFN SwiGLU ( x ) = ( Swish ( x W 1 ) ⊙ x W 3 ) W 2 \text{FFN}_{\text{SwiGLU}}(x) = (\text{Swish}(xW_1) \odot xW_3)W_2 FFNSwiGLU(x)=(Swish(xW1)⊙xW3)W2
这里的 Swish ( x W 1 ) \text{Swish}(xW_1) Swish(xW1) 部分起到了一个数据依赖的门控作用。它不像ReLU那样是一个固定的、基于0的硬开关,而是一个平滑的、值域在0附近的连续函数。这个门控的值取决于输入x本身,它决定了xW_3$这条“信息通路”中的信息有多少能够被传递下去。这种动态的、上下文相关的控制能力,赋予了FFN层更强的表达能力,并被证实可以提升模型质量。
从ReLU到SwiGLU的转变,反映了LLM架构设计的一个更宏大的趋势:让模型的每一个计算单元都变得更加动态和数据依赖。ReLU的激活决策是静态的,而SwiGLU的门控决策是动态的、连续的、并且是模型学习到的。这与注意力机制使用数据依赖的权重来组合V向量、以及下一章将要介绍的MoE使用数据依赖的路由器来选择专家,在设计哲学上是一脉相承的。现代LLM正在从一个静态的数据处理流水线,演变为一个高度动态的、学会在内部自适应地控制信息流的复杂系统。
PyTorch实现
下面是一个典型的使用SwiGLU的FFN层的实现,它遵循了Llama等模型的结构。
import torch
import torch.nn as nn
import torch.nn.functional as F
class SwiGLUFFN(nn.Module):
def __init__(self, dim: int, hidden_dim: int, multiple_of: int = 256):
super().__init__()
# 隐藏层维度通常是dim的倍数,并向上取整到multiple_of的倍数
# 这是Llama等模型中的常见做法
if hidden_dim is None:
hidden_dim = 4 * dim
hidden_dim = int(2 * hidden_dim / 3)
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
def forward(self, x):
# x 的形状: (batch_size, seq_len, dim)
# Swish(x * W1)
swish_gate = F.silu(self.w1(x))
# x * W3
value_path = self.w3(x)
# 逐元素相乘
gated_value = swish_gate * value_path
# 最终投影
return self.w2(gated_value)
# 示例
ffn_layer = SwiGLUFFN(dim=128, hidden_dim=256)
input_tensor = torch.randn(2, 10, 128)
output_tensor = ffn_layer(input_tensor)
print("Output shape:", output_tensor.shape)
4. 万亿参数的扩展:专家混合(MoE)范式
一个MoE层通常用来替代Transformer块中的标准FFN层。它的结构由两部分组成:
-
N个“专家”(Experts):这是一组并行的、独立的标准FFN网络。每个专家都可以被认为是某个领域的“专才”。
-
一个“门控网络”(Gating Network):也称为“路由器”(Router)。它的作用是为每一个输入的token,决定应该由哪些专家来处理。
其工作流程如下:
- 路由:对于序列中的每一个token,路由器会计算一个分数,评估该token与每个专家的“匹配度”。然后,它会根据分数选出Top-K个最匹配的专家(例如,在Mixtral中K=2)。
- 处理:只有被选中的这K个专家会被“激活”,并对该token进行计算。所有其他N−K个专家则保持“沉默”,不参与计算。
- 组合:该token的最终输出是这K个被激活专家的输出的加权和,权重也由路由器给出 。
MoE的巨大优势在于,模型可以拥有极大的总参数量(所有专家的参数总和),但在处理任何一个token时,只动用了其中一小部分(路由器 + K个专家)的参数。这使得计算成本保持在可控范围内,同时模型获得了巨大的容量。
当然,MoE也带来了新的挑战,其中最主要的是负载均衡(Load Balancing)。如果路由器倾向于总是将大部分token发送给少数几个“明星专家”,会导致这些专家过劳,而其他专家则被闲置,从而浪费了模型的容量。为了解决这个问题,训练MoE模型时通常会引入一个辅助损失函数(auxiliary loss),以激励路由器将token更均匀地分配给所有专家。
Mixtral 、Grok等模型的成功,已经证明了MoE是通往万亿参数规模的关键技术。
超越了单纯的规模扩展,MoE架构还代表了向模块化AI迈出的重要一步。一个密集模型是一个难以解析的“黑箱”,所有参数都纠缠在一起。而MoE模型在设计上就是模块化的,它拥有离散的“专家”组件。这自然引出了一系列新的研究方向:专家模型真的会自发学习到有意义的专长吗?早期的希望是它们能学习到人类可理解的语义领域(如“物理专家”、“历史专家”)。近期的研究表明,实际的路由决策更为复杂,可能依赖于位置、语法结构等多种因素。然而,这种潜力是巨大的。一些前沿工作,如Monet ,正试图通过专门的设计来训练出“单义性”(monosemantic)的专家。这预示着一个激动人心的未来:如果我们能理解并引导专家的特化,或许就能实现更高层次的模型可解释性和可控性。例如,我们可能不再需要对整个百亿参数模型进行微调来消除毒性,而是可以精准地定位并“再教育”或替换掉那几个负责生成有害内容的“坏专家”。
PyTorch实现
下面是一个简化的MoE层实现,它包含了路由和专家调度的核心逻辑。
import torch
import torch.nn as nn
import torch.nn.functional as F
class TopKGate(nn.Module):
"""一个简单的Top-K门控网络"""
def __init__(self, input_dim, num_experts, k=2):
super().__init__()
self.k = k
self.gate_linear = nn.Linear(input_dim, num_experts)
def forward(self, x):
# x: (..., input_dim)
logits = self.gate_linear(x)
# 找到top-k的logit和索引
top_k_logits, top_k_indices = torch.topk(logits, self.k, dim=-1)
# 对top-k的logit应用softmax得到权重
top_k_weights = F.softmax(top_k_logits, dim=-1)
# 创建一个稀疏的权重张量
full_weights = torch.zeros_like(logits)
full_weights.scatter_(-1, top_k_indices, top_k_weights)
return full_weights, top_k_indices
class MoELayer(nn.Module):
"""一个简化的MoE层"""
def __init__(self, input_dim, output_dim, num_experts, k=2):
super().__init__()
self.num_experts = num_experts
self.k = k
self.gate = TopKGate(input_dim, num_experts, k)
self.experts = nn.ModuleList()
def forward(self, x):
# x: (batch_size, seq_len, input_dim)
batch_size, seq_len, dim = x.shape
x = x.view(-1, dim) # (batch_size * seq_len, input_dim)
# 1. 获取门控权重和索引
gate_weights, top_k_indices = self.gate(x) # (bs*sl, n_exp), (bs*sl, k)
# 2. 高效路由
final_output = torch.zeros_like(x.unsqueeze(1).expand(-1, self.k, -1)) # (bs*sl, k, dim)
# 将输入x按照top_k_indices分配给专家
# dispatch_mask: (bs*sl, n_exp, k)
dispatch_mask = F.one_hot(top_k_indices, num_classes=self.num_experts).permute(0, 2, 1)
for i in range(self.num_experts):
# 找到被分配给专家i的token
# expert_mask: (bs*sl, k)
expert_mask = dispatch_mask[:, i, :]
# token_indices: (num_tokens_for_expert_i)
token_indices = expert_mask.nonzero(as_tuple=True)
if token_indices.numel() > 0:
expert_input = x[token_indices]
expert_output = self.experts[i](expert_input)
# 将输出放回正确的位置
# (bs*sl, k, dim)
final_output.index_put_((token_indices, expert_mask[token_indices].argmax(dim=1)), expert_output)
# 3. 组合结果
# gate_weights: (bs*sl, n_exp) -> (bs*sl, k)
combine_weights = gate_weights.gather(dim=-1, index=top_k_indices)
# (bs*sl, k, dim) * (bs*sl, k, 1) -> (bs*sl, k, dim)
weighted_output = final_output * combine_weights.unsqueeze(-1)
# 对k个专家的输出求和
output = weighted_output.sum(dim=1) # (bs*sl, dim)
return output.view(batch_size, seq_len, -1)
# 示例
moe_layer = MoELayer(input_dim=128, output_dim=128, num_experts=8, k=2)
input_tensor = torch.randn(2, 10, 128)
output_tensor = moe_layer(input_tensor)
print("Output shape:", output_tensor.shape)
总结
我们从2017年的经典Transformer出发,一路走来,见证了一场深刻而全面的技术进化。这场革命并非单一技术的突破,而是一系列创新在不同层面协同作用的结果。为了应对效率和性能的双重挑战,LLM的架构师们像精密的钟表匠一样,对模型的每一个齿轮都进行了重构和优化。
- 为了突破训练吞吐量的极限,FlashAttention通过I/O感知的分块计算,让我们不再受制于内存墙。
- 为了加速推理过程,Grouped-Query Attention在模型质量和KV缓存大小之间找到了完美的平衡点。
- 为了让模型更深刻地理解序列顺序,RoPE用优雅的旋转代替了生硬的加法,赋予了模型天然的相对位置感知能力。
- 在最基础的计算模块中,RMSNorm和SwiGLU分别以更低的计算开销和更强的表达能力,取代了它们的前辈LayerNorm和ReLU。
- 为了实现前所未有的模型规模,Mixture of Experts范式通过稀疏激活,打破了参数量与计算成本之间的刚性耦合。
下表直观地展示了这场进化带来的“今昔之比”:
组件 | 原始Transformer (2017) | 现代LLM (如Llama/Mistral) |
---|---|---|
位置编码 | 加性正弦编码 | 旋转位置编码 (RoPE) |
注意力变体 | 多头注意力 (MHA) | 分组查询注意力 (GQA) |
注意力实现 | 朴素实现 (I/O受限) | I/O感知实现 (FlashAttention) |
FFN激活函数 | ReLU | SwiGLU |
归一化层 | LayerNorm | RMSNorm |
扩展策略 | 密集型 (Dense) | 稀疏型 (Mixture of Experts) |
对于我们初学者而言,理解这些技术不仅是追赶潮流,更是掌握构建未来模型的底层逻辑。I/O感知的设计、推理与训练的权衡、动态的计算机制、模块化的架构思想——这些原则不仅塑造了今天的LLM,也必将指引着明天的技术突破。
这一篇主要是介绍现代LLM的一些改进,下一篇将基于这些改进,修改我们之前的模型代码,并基于改进后的模型重新训练一个新的模型。
关注我的公众号不走丢
附录
GitHub链接:https://github.com/JimmysAIPG/MiniLLMs