Gradient accumulation is a technique used when the batch size you want to train with is larger than the largest batch size that fits in the machine's memory. The idea is to accumulate the model's gradients over several smaller batches and only let the optimizer update the parameters once the desired effective batch size has been reached. For example, if you want to train with a batch size of 64 but the machine can only fit a batch size of 16, you can set gradient_accumulation_steps=4 and batch_size=16, so that the gradients are accumulated over 4 batches before the parameters are updated.
A schematic implementation of gradient accumulation in PyTorch:
device = "cuda"
model.to(device)
## 设置累积步数
gradient_accumulation_steps = 2
for index, batch in enumerate(training_dataloader):
inputs, targets = batch
inputs = inputs.to(device)
targets = targets.to(device)
outputs = model(inputs)
loss = loss_function(outputs, targets)
loss = loss / gradient_accumulation_steps ## scale
loss.backward()
## 只有gradient_accumulation_steps之后才更新模型参数
if (index + 1) % gradient_accumulation_steps == 0:
optimizer.step()
scheduler.step()
optimizer.zero_grad()
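If you train with the Hugging Face Accelerate library (see the documentation referenced at the end), the same pattern can be written with the accelerator.accumulate context manager, which handles the loss scaling and skips the optimizer update until enough batches have been accumulated. A rough sketch, assuming the same model, optimizer, scheduler, loss_function and training_dataloader as above:

from accelerate import Accelerator

## Accelerate handles the scaling and the "update every N steps" logic internally
accelerator = Accelerator(gradient_accumulation_steps=2)
model, optimizer, training_dataloader, scheduler = accelerator.prepare(
    model, optimizer, training_dataloader, scheduler
)
for batch in training_dataloader:
    with accelerator.accumulate(model):
        inputs, targets = batch
        outputs = model(inputs)
        loss = loss_function(outputs, targets)
        accelerator.backward(loss)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()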
We can verify that gradient accumulation works with a simple network.
import torch

# super simple little MLP
net = torch.nn.Sequential(
    torch.nn.Linear(16, 32),
    torch.nn.GELU(),
    torch.nn.Linear(32, 1)
)
torch.random.manual_seed(42)
## batch size of 8
batch_size = 8
x = torch.randn(batch_size, 16)
y = torch.randn(batch_size, 1)
net.zero_grad()
yhat = net(x)
loss = torch.nn.functional.mse_loss(yhat, y)
loss.backward()
## print part of the gradient
print(net[0].weight.grad.view(-1)[:10])

#### now run the model with gradient accumulation, splitting the batch into smaller sub-batches
## reset the gradients first
net.zero_grad()
## number of accumulation steps
gradient_accumulation_steps = 4
sub_batch = batch_size // gradient_accumulation_steps
for i in range(0, batch_size, sub_batch):
    yhat = net(x[i:i+sub_batch])
    loss = torch.nn.functional.mse_loss(yhat, y[i:i+sub_batch])
    loss = loss / gradient_accumulation_steps  # scale <-- have to add back the "normalizer"!
    loss.backward()
## print the gradient: it matches the full-batch gradient printed above
print(net[0].weight.grad.view(-1)[:10])
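Rather than eyeballing the printed values, the two gradients can also be compared programmatically. A minimal self-contained sketch of the same check using torch.allclose:

import torch

torch.manual_seed(0)
net = torch.nn.Sequential(torch.nn.Linear(16, 32), torch.nn.GELU(), torch.nn.Linear(32, 1))
x, y = torch.randn(8, 16), torch.randn(8, 1)

## gradient of the full batch of 8
net.zero_grad()
torch.nn.functional.mse_loss(net(x), y).backward()
full_grad = net[0].weight.grad.clone()

## gradient accumulated over 4 sub-batches of 2
net.zero_grad()
for i in range(0, 8, 2):
    loss = torch.nn.functional.mse_loss(net(x[i:i+2]), y[i:i+2]) / 4  # scale by the number of steps
    loss.backward()
print(torch.allclose(full_grad, net[0].weight.grad, atol=1e-6))  # True, up to floating point error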
Note, however, that gradient accumulation does not work for every model. For example, in models that use Batch Normalization the normalization statistics depend on the other samples in the batch, so the gradients computed on small sub-batches do not add up to the full-batch gradient.
import torch

# super simple little MLP, this time with BatchNorm1d in the network
net = torch.nn.Sequential(
    torch.nn.Linear(16, 32),
    torch.nn.BatchNorm1d(32),
    torch.nn.Linear(32, 1)
)
torch.random.manual_seed(42)
## batch size of 8
batch_size = 8
x = torch.randn(batch_size, 16)
y = torch.randn(batch_size, 1)
net.zero_grad()
yhat = net(x)
loss = torch.nn.functional.mse_loss(yhat, y)
loss.backward()
## print part of the gradient
print(net[0].weight.grad.view(-1)[:10])

### test gradient accumulation
net.zero_grad()
## number of accumulation steps
gradient_accumulation_steps = 4
sub_batch = batch_size // gradient_accumulation_steps
for i in range(0, batch_size, sub_batch):
    yhat = net(x[i:i+sub_batch])
    loss = torch.nn.functional.mse_loss(yhat, y[i:i+sub_batch])
    loss = loss / gradient_accumulation_steps  # scale <-- have to add back the "normalizer"!
    loss.backward()
## print part of the gradient: this time it differs from the gradient printed above
print(net[0].weight.grad.view(-1)[:10])
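By contrast, normalization layers that use only per-sample statistics, such as the LayerNorm used inside transformer models like GPT-2, do not break gradient accumulation. A quick sketch of the same check with BatchNorm1d swapped for LayerNorm:

import torch

torch.manual_seed(0)
## same MLP, but with the batch-independent LayerNorm instead of BatchNorm1d
net = torch.nn.Sequential(torch.nn.Linear(16, 32), torch.nn.LayerNorm(32), torch.nn.Linear(32, 1))
x, y = torch.randn(8, 16), torch.randn(8, 1)

net.zero_grad()
torch.nn.functional.mse_loss(net(x), y).backward()
full_grad = net[0].weight.grad.clone()

net.zero_grad()
for i in range(0, 8, 2):
    (torch.nn.functional.mse_loss(net(x[i:i+2]), y[i:i+2]) / 4).backward()
## LayerNorm normalizes each sample independently, so the gradients match again
print(torch.allclose(full_grad, net[0].weight.grad, atol=1e-6))  # True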
Ignoring numerical error, for models where gradient accumulation is valid it should in theory give exactly the same result as training with the larger batch size. However, in October 2024 a trl issue reported that the loss was larger when fine-tuning LLMs with gradient accumulation. The bug was later analyzed and fixed, and a blog post was written to explain it: when fine-tuning an LLM, gradient accumulation can differ from training with the original batch size because the average sequence length may differ between the small batches. PyTorch's cross_entropy averages the loss over the actual (non-padding) tokens of each batch (padding positions are skipped via ignore_index=-100; the transformers Trainer sets the label of padding tokens to -100 during training), so averaging the per-sub-batch means is not the same as averaging over all tokens at once, which makes the accumulated loss larger than the loss of an equally sized batch.
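A tiny numeric example makes the problem concrete. Suppose one sub-batch has 3 non-padding tokens and another has only 1 (the per-token losses below are made up purely for illustration); averaging the two sub-batch means is not the same as averaging over all 4 tokens at once:

## hypothetical per-token losses for two sub-batches with 3 and 1 non-padding tokens
sub1, sub2 = [1.0, 1.0, 1.0], [5.0]
## full batch: cross_entropy averages over all non-padding tokens at once
full_batch_loss = sum(sub1 + sub2) / len(sub1 + sub2)                    # 8 / 4 = 2.0
## gradient accumulation: each sub-batch is averaged on its own, then the means are averaged
accumulated_loss = (sum(sub1) / len(sub1) + sum(sub2) / len(sub2)) / 2   # (1 + 5) / 2 = 3.0
print(full_batch_loss, accumulated_loss)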
Let's verify this with GPT-2:
from torch.nn import functional as F
from transformers import AutoTokenizer, GPT2LMHeadModel

model = GPT2LMHeadModel.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = 'left'
test_data = [
    "This is just as test.",
    "The weather is very good today.",
    "You are all resolved rather to die than to famish?",
    "We managed to formulate a new methodology that solves the issue.",
    "Before we tried fixing the issue, could we first reproduce the error?",
    "We cannot, sir, we are undone already.",
    "They are famish for food.",
    "Here's the red poinsettia, one of the plants synonymous with the holiday season in the United States.",
]
batch_size = 8
tokend_data = tokenizer(test_data, return_tensors='pt', padding=True)
x = tokend_data['input_ids'][:, :-1]
y = tokend_data['input_ids'][:, 1:]

## batch size of 8: compute the loss and inspect the gradient of the last layer
## (padding tokens are NOT ignored in the loss here)
model.zero_grad()
output1 = model(input_ids=x, attention_mask=tokend_data['attention_mask'][:, :-1])
loss = F.cross_entropy(output1.logits.view(-1, output1.logits.size(-1)), y.contiguous().view(-1))
print(loss)
loss.backward()
## print part of the gradient
print(model.lm_head.weight.grad.view(-1)[:10])
## gradient accumulation with 4 steps: compute the loss and inspect the gradient of the last layer
## (padding tokens are NOT ignored in the loss, so every sub-batch has the same number of tokens)
model.zero_grad()
total_loss = 0
## number of accumulation steps
gradient_accumulation_steps = 4
sub_batch = batch_size // gradient_accumulation_steps
for i in range(0, batch_size, sub_batch):
    output2 = model(input_ids=tokend_data['input_ids'][i:i+sub_batch, :-1],
                    attention_mask=tokend_data['attention_mask'][i:i+sub_batch, :-1])
    loss = F.cross_entropy(output2.logits.view(-1, output2.logits.size(-1)),
                           tokend_data['input_ids'][i:i+sub_batch, 1:].contiguous().view(-1))
    loss = loss / gradient_accumulation_steps  # <-- have to add back the "normalizer"!
    total_loss += loss
    loss.backward()
print(total_loss)
## print part of the gradient: since padding tokens are not ignored, every sub-batch has the same
## number of tokens, so both the accumulated loss and the gradient match the large-batch run
print(model.lm_head.weight.grad.view(-1)[:10])
######## In real training, however, the padding tokens are excluded from the loss. GPT-2's padding
######## token id here is 50256; if we repeat the test with it ignored, the accumulated gradient
######## no longer matches the full-batch gradient. ########

## batch size of 8: compute the loss and inspect the gradient of the last layer
## (padding tokens ARE ignored in the loss via cross_entropy's ignore_index)
model.zero_grad()
output1 = model(input_ids=x, attention_mask=tokend_data['attention_mask'][:, :-1])
loss = F.cross_entropy(output1.logits.view(-1, output1.logits.size(-1)), y.contiguous().view(-1),
                       ignore_index=50256)
print(loss)
loss.backward()
## print part of the gradient
print(model.lm_head.weight.grad.view(-1)[:10])

## gradient accumulation
model.zero_grad()
total_loss = 0
gradient_accumulation_steps = 4
sub_batch = batch_size // gradient_accumulation_steps
for i in range(0, batch_size, sub_batch):
    output2 = model(input_ids=tokend_data['input_ids'][i:i+sub_batch, :-1],
                    attention_mask=tokend_data['attention_mask'][i:i+sub_batch, :-1])
    loss = F.cross_entropy(output2.logits.view(-1, output2.logits.size(-1)),
                           tokend_data['input_ids'][i:i+sub_batch, 1:].contiguous().view(-1),
                           ignore_index=50256)
    loss = loss / gradient_accumulation_steps  # <-- have to add back the "normalizer"!
    total_loss += loss
    loss.backward()
print(total_loss)
## print part of the gradient: with padding ignored, the average sub-batch length is no longer the
## same, so the accumulated loss is larger and the gradient differs from the full-batch run
print(model.lm_head.weight.grad.view(-1)[:10])
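The fix described in the referenced issue and blog post is to stop averaging inside each sub-batch and instead normalize by the total number of non-padding tokens in the whole accumulated batch. A sketch of that idea, reusing model, tokend_data, batch_size and sub_batch from above (a hand-rolled illustration, not the exact code of the transformers fix):

## fix: use a summed token loss and divide by the total number of non-padding target tokens
## across all sub-batches, so differing sub-batch lengths no longer skew the loss
model.zero_grad()
labels = tokend_data['input_ids'][:, 1:]
num_items_in_batch = (labels != 50256).sum()  # total non-padding target tokens in the full batch
total_loss = 0
for i in range(0, batch_size, sub_batch):
    output3 = model(input_ids=tokend_data['input_ids'][i:i+sub_batch, :-1],
                    attention_mask=tokend_data['attention_mask'][i:i+sub_batch, :-1])
    loss = F.cross_entropy(output3.logits.view(-1, output3.logits.size(-1)),
                           labels[i:i+sub_batch].contiguous().view(-1),
                           ignore_index=50256, reduction='sum') / num_items_in_batch
    total_loss += loss
    loss.backward()
print(total_loss)                                # now matches the full-batch loss above
print(model.lm_head.weight.grad.view(-1)[:10])   # and so does the gradient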
References
- https://github.com/huggingface/trl/issues/2175 (the figures in this post are taken from this issue)
- https://unsloth.ai/blog/gradient
- The Accelerate gradient accumulation documentation