🚀PyTorch 常用优化器总结(入门必备)
在用 PyTorch 写神经网络的时候,我们经常会看到下面这样的代码:
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
一开始我是看不明白:
这 optimizer 是嘎哈的?怎么还有 SGD、Adam、RMSprop 一大堆的选项?到底整哪个好?
后来我自己动手实验了一圈,总算把这些优化器的种类整理的差不多了
🤔 啥是优化器?
一句话总结:
优化器就是用来根据“损失函数”调整网络参数的工具。
模型每一次训练,会算出一个损失(loss),然后把这个 loss 反向传播(.backward()
)出每个参数的梯度,优化器再根据梯度,来更新权重。
🔧 PyTorch 优化器我用过的几种
1️⃣ SGD:最基本的随机梯度下降
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
- 我刚开始学的时候,用的就是这个。
- 每次都用当前的梯度,按学习率更新参数。
- 可以加 momentum(动量)让模型更快收敛、不抖动:
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
🎯 适合传统图像任务,比如 LeNet、VGG,或者你刚上手神经网络时。
2️⃣ Adam:目前最常用、最智能的优化器之一
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
这个我在例子中都用到过,尤其是在跑 NLP 和图像分类时特别香。
特点是:
- 自动调整每个参数的学习率
- 不容易抖动,训练稳定
- 训练速度快
我现在用 transformer 模型,基本都是用的 Adam
或 AdamW
。
3️⃣ RMSprop:RNN 推荐使用的优化器
optimizer = torch.optim.RMSprop(model.parameters(), lr=0.01)
我之前做 LSTM 情感分类的时候试过这个,效果比 Adam 稍微差一点点,但也还不错。
4️⃣ Adagrad & Adadelta:对稀疏特征比较友好
optimizer = torch.optim.Adagrad(model.parameters(), lr=0.01)
这两个优化器我用得比较少,主要是在看 NLP 文献时有人用。优点是对稀疏输入(比如词袋模型)特别有用(抄的话)。
5️⃣ AdamW:适合 BERT / GPT / ViT 等 transformer 系列
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01)
这个优化器是我最近才了解的,它是对 Adam 的升级版本,把 L2 正则化和优化解耦了。
几乎所有 transformer 模型(包括 HuggingFace)推荐用 AdamW
。
🎓 我是怎么选优化器的?
我根据自己的实践,总结出一个选择公式:
场景 | 推荐优化器 |
---|---|
图像分类(CNN) | SGD + momentum / Adam |
自然语言处理(NLP) | Adam / AdamW |
RNN、LSTM | RMSprop |
Transformer | AdamW |
模型很小 or 任务简单 | SGD 就够了 |