多卡分布式训练时,模型中存在部分参数没有参与梯度更新,可以先看看是哪个模块,再定位具体的网络层,可以逐个关闭某个层的梯度进行排查
param.requires_grad = False
或者 torch.nn.parallel.DistributedDataParallel中加入find_unused_parameters参数并设置初始值为True
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)