-1

我正在尝试加载使用 Pytorch 训练过的模型,但我不断收到以下错误:

文件“convert.py”,第 12 行,在 model.load_state_dict(torch.load('model/model_vgg2d_2.pth')) 文件“/usr/local/lib/python3.5/dist-packages/torch/nn/modules /module.py",第 490 行,在 load_state_dict .format(name)) KeyError: 'unexpected key "module.features.0.weight" in state_dict'

下面是我的代码:

import torch.onnx
import torch.nn as nn

class TempModel(nn.Module):
    def __init__(self):
        super(TempModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 5, (3, 3))
    def forward(self, inp):
        return self.conv1(inp)

model = nn.DataParallel(TempModel())
model.load_state_dict(torch.load('model/model_vgg2d_2.pth'))
dummy_input = Variable(torch.randn(1, 3, 224, 224))
torch.onnx.export(model, dummy_input, "model_onnx/model_vgg2d_0.onnx")

我正在使用用于训练模型的同一台机器(具有多个 GPU)。任何想法我做错了什么?

4

1 回答 1

-1

加载时state_dict,您需要它是state_dict同一模型的 a :您不能state_dict将 VGG 模型的 a 加载到完全不同的BasicModel.


旧答案
您保存了模型而没有nn.DataParallel应用于模型,现在您在添加后尝试加载。尝试

model = TempModel()
model.load_state_dict(torch.load('model/model_vgg2d_2.pth'))
model = nn.DataParallel(model)  # parallel AFTER load
于 2018-10-23T09:05:19.913 回答