使用 torch.nn.DataParallel 训练和保存的模型,其 key 中多了’module’,在加载到单GPU或CPU环境中,会报错找不到key,需要将它去掉。
def load_model_without_module(model_state_file):
from collections import OrderedDict
new_checkpoint = OrderedDict()
checkpoint = torch.load(model_state_file)['state_dict']