4

我训练了一个我从这个存储库修改的香草 vae 。当我尝试使用经过训练的模型时,我无法使用load_from_checkpoint. 我的检查点对象和我的对象之间似乎不匹配lightningModule

我已经VAEXperiment使用pytorch-lightning LightningModule. 我尝试将权重加载到网络中:

#building a new model
model = VanillaVAE(**config['model_params'])
model.build_layers()

#loading the weights
experiment = VAEXperiment(model, config['exp_params'])
experiment.load_from_checkpoint(path_to_checkpoint, config['exp_params'])

我也试过:

checkpoint = torch.load(path_to_checkpoint, map_location=lambda storage, loc: storage)
model.load_state_dict(checkpoint['state_dict'])

但我得到一个错误 Unexpected key(s) in state_dict: "model.encoder.0.0.weight", "model.encoder.0.0.bias"......

我还在 https://github.com/PyTorchLightning/pytorch-lightning/issues/924 https://github.com/PyTorchLightning/pytorch-lightning/issues/2798上关注了这个问题

为什么我会收到此错误?是因为我的模型中的编码器和解码器模块吗?根据 git 上的问题日志,似乎错误已解决。我究竟做错了什么?

4

1 回答 1

3

从评论中发布答案:

experiment.load_state_dict(checkpoint['state_dict'])
于 2020-08-04T12:45:20.950 回答