【mmengine】优化器封装(OptimWrapper)(入门)优化器封装 vs 优化器

  • MMEngine 实现了优化器封装,为用户提供了统一的优化器访问接口。优化器封装支持不同的训练策略,包括混合精度训练、梯度累加和梯度截断。用户可以根据需求选择合适的训练策略。优化器封装还定义了一套标准的参数更新流程,用户可以基于这一套流程,实现同一套代码,不同训练策略的切换。
  • 分别基于 Pytorch 内置的优化器和 MMEngine 的优化器封装(OptimWrapper)进行单精度训练混合精度训练梯度累加,对比二者实现上的区别。

一、 基于 Pytorch 的 SGD 优化器实现单精度训练

import torch
from torch.optim import SGD
import torch.nn as nn
import torch.nn.functional as F

inputs = [torch.zeros(10, 1, 1)] * 10
targets = [torch.ones(10, 1, 1)] * 10
model = nn.Linear(1, 1)
optimizer = SGD(model.parameters(), lr=0.01)
optimizer.zero_grad()

for input, target in zip(inputs, targets):
    output = model(input)
    loss = F.l1_loss(output, target)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

二、 使用 MMEngine 的优化器封装实现单精度训练

from mmengine.optim import OptimWrapper

optim_wrapper = OptimWrapper(optimizer=optimizer)

for input, target in zip(inputs, targets):
    output = model(input)
    loss = F.l1_loss(output, target)
    optim_wrapper.update_params(loss)

优化器封装的 update_params 实现了标准的梯度计算、参数更新和梯度清零流程,可以直接用来更新模型参数。
在这里插入图片描述

三、 基于 Pytorch 的 SGD 优化器实现混合精度训练

在这里插入图片描述

  • 混合精度训练:单精度 float和半精度 float16 混合,其优势为:
    • 内存占用更少
    • 计算更快
from torch.cuda.amp import autocast

model = model.cuda()
inputs = [torch.zeros(10, 1, 1, 1)] * 10
targets = [torch.ones(10, 1, 1, 1)] * 10

for input, target in zip(inputs, targets):
    with autocast():
        output = model(input.cuda())
    loss = F.l1_loss(output, target.cuda())
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

四、 基于 MMEngine 的 优化器封装实现混合精度训练

from mmengine.optim import AmpOptimWrapper

optim_wrapper = AmpOptimWrapper(optimizer=optimizer)

for input, target in zip(inputs, targets):
    with optim_wrapper.optim_context(model):
        output = model(input.cuda())
    loss = F.l1_loss(output, target.cuda())
    optim_wrapper.update_params(loss)

在这里插入图片描述

  • 混合精度训练需要使用 AmpOptimWrapper,他的 optim_context 接口类似 autocast,会开启混合精度训练的上下文。除此之外他还能加速分布式训练时的梯度累加,这个我们会在下一个示例中介绍

五、 基于 Pytorch 的 SGD 优化器实现混合精度训练和梯度累加

for idx, (input, target) in enumerate(zip(inputs, targets)):
    with autocast():
        output = model(input.cuda())
    loss = F.l1_loss(output, target.cuda())
    loss.backward()
    if idx % 2 == 0:
        optimizer.step()
        optimizer.zero_grad()

六、基于 MMEngine 的优化器封装实现混合精度训练和梯度累加

optim_wrapper = AmpOptimWrapper(optimizer=optimizer, accumulative_counts=2)

for input, target in zip(inputs, targets):
    with optim_wrapper.optim_context(model):
        output = model(input.cuda())
    loss = F.l1_loss(output, target.cuda())
    optim_wrapper.update_params(loss)

在这里插入图片描述
只需要配置 accumulative_counts 参数,并调用 update_params 接口就能实现梯度累加的功能。除此之外,分布式训练情况下,如果我们配置梯度累加的同时开启了 optim_wrapper 上下文,可以避免梯度累加阶段不必要的梯度同步。

七、 获取学习率/动量

优化器封装提供了 get_lrget_momentum 接口用于获取优化器的一个参数组的学习率:

import torch.nn as nn
from torch.optim import SGD

from mmengine.optim import OptimWrapper

model = nn.Linear(1, 1)
# 优化器
optimizer = SGD(model.parameters(), lr=0.01)
# 封装器
optim_wrapper = OptimWrapper(optimizer)

print("get info from optimizer ------")
print(optimizer.param_groups[0]['lr'])  # 0.01
print(optimizer.param_groups[0]['momentum'])  # 0
print("get info from wrapper ------")
print(optim_wrapper.get_lr())  # {'lr': [0.01]}
print(optim_wrapper.get_momentum())  # {'momentum': [0]}

在这里插入图片描述

八、 导出/加载状态字典

优化器封装和优化器一样,提供了 state_dictload_state_dict 接口,用于导出/加载优化器状态,对于 AmpOptimWrapper,优化器封装还会额外导出混合精度训练相关的参数:

import torch.nn as nn
from torch.optim import SGD
from mmengine.optim import OptimWrapper, AmpOptimWrapper

model = nn.Linear(1, 1)
# 优化器
optimizer = SGD(model.parameters(), lr=0.01)

# ---- 导出 ---- #
print("print state_dict")
# 单精度封装器
optim_wrapper = OptimWrapper(optimizer=optimizer)
# 混合精度封装器
amp_optim_wrapper = AmpOptimWrapper(optimizer=optimizer)

# 导出状态字典
optim_state_dict = optim_wrapper.state_dict()
amp_optim_state_dict = amp_optim_wrapper.state_dict()
print(optim_state_dict)
print(amp_optim_state_dict)



# ---- 加载 ---- #
print("load state_dict")
# 单精度封装器
optim_wrapper_new = OptimWrapper(optimizer=optimizer)
# 混合精度封装器
amp_optim_wrapper_new = AmpOptimWrapper(optimizer=optimizer)

# 加载状态字典
amp_optim_wrapper_new.load_state_dict(amp_optim_state_dict)
optim_wrapper_new.load_state_dict(optim_state_dict)

在这里插入图片描述

九、 使用多个优化器

OptimWrapperDict 的核心功能是支持批量导出/加载所有优化器封装的状态字典;支持获取多个优化器封装的学习率、动量。如果没有 OptimWrapperDict,MMEngine 就需要在很多位置对优化器封装的类型做 if else 判断,以获取所有优化器封装的状态。

from torch.optim import SGD
import torch.nn as nn

from mmengine.optim import OptimWrapper, OptimWrapperDict

# model1
gen = nn.Linear(1, 1)
# model2
disc = nn.Linear(1, 1)

# optimizer1
optimizer_gen = SGD(gen.parameters(), lr=0.01)
# optimizer2
optimizer_disc = SGD(disc.parameters(), lr=0.01)

# wrapper1
optim_wapper_gen = OptimWrapper(optimizer=optimizer_gen)
# wrapper2
optim_wapper_disc = OptimWrapper(optimizer=optimizer_disc)

# wrapper_dict = wrapper1 + wrapper2
optim_dict = OptimWrapperDict(gen=optim_wapper_gen, disc=optim_wapper_disc)

print("wrapper_dict = wrapper1 + wrapper2")
print(optim_dict.get_lr())  # {'gen.lr': [0.01], 'disc.lr': [0.01]}
print(optim_dict.get_momentum())  # {'gen.momentum': [0], 'disc.momentum': [0]}

在这里插入图片描述

### 关于 MMEngine 的详细介绍 #### 什么是 MMEngineMMEngine 是 OpenMMLab 推出的一个通用深度学习工具箱,旨在提供灵活且高效的模块化设计来支持多种下游任务。它不仅简化了模型开发流程,还提供了丰富的功能用于实验管理、日志记录以及分布式训练的支持[^3]。 #### 主要特点 - **灵活性**:通过高度解耦的设计允许用户自定义数据加载器、优化器和其他组件。 - **高效性**:内置对多 GPU 和分布式训练的良好支持,从而加速大规模模型训练过程。 - **易用性**:清晰直观的 API 设计降低了使用者的学习成本并提高了生产力。 --- ### 如何安装 MMEngine? 可以通过 Python 包管理工具 `pip` 来完成 MMEngine 及其依赖项的安装: ```bash pip install mmengine ``` 如果需要更高级的功能或者特定版本,则可以借助官方推荐的方式先安装必要的基础库如 PyTorch 和 CUDA 工具链后再继续操作: ```bash # 首先前置条件需满足PyTorch,CUDA等相关环境已正确设置好. pip install openmim mim install mmcv-full mim install mmcls ``` 以上命令确保了完整的生态体系被构建起来以便后续项目顺利运行[^2]. 对于 Windows 用户来说,在配置过程中可能会碰到一些兼容性和路径映射方面的问题;此时建议参照具体文档说明逐步排查解决可能存在的障碍。 --- ### 使用教程概览 以下是利用 MMEngine 构建简单分类网络的例子展示如何初始化模型结构及其参数更新机制等内容: ```python from mmengine.model import BaseModel import torch.nn as nn class SimpleCNN(BaseModel): def __init__(self): super().__init__() self.conv = nn.Conv2d(1, 32, kernel_size=3, stride=1) self.fc = nn.Linear(32 * 26 * 26, num_classes) def forward(self, imgs, labels=None): x = self.conv(imgs) x = x.view(x.size(0), -1) outputs = self.fc(x) if labels is not None: loss = F.cross_entropy(outputs, labels) return {'loss': loss} else: return {'pred':outputs.argmax(dim=-1)} ``` 上述代码片段展示了继承自 `BaseModel` 类的新类型对象创建方式,并实现了前向传播逻辑部分。 --- ### 总结 综上所述,MMEngine 不仅是一个强大的框架,而且拥有详尽的帮助资源可供查阅学习。无论是初学者还是资深开发者都能从中受益匪浅。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

BILLY BILLY

你的奖励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值