1

我正在使用 Chainer 训练(微调)Resnet 模型,然后使用检查点进行评估。检查点是一个 npz 文件,结构如下:

npz 检查点中的文件列表

当我加载模型进行评估时chainer.serializers.load_npz(args.load, model)(其中模型是标准 resnet),我收到以下错误:KeyError:'rpn/loc/b is not a file in the archive'。

我认为问题在于模型中的文件没有“updater/optimizer/faster/extractor”前缀。

如何更改生成的 npz 中的文件名称以删除前缀,或者我应该采取什么其他措施来解决问题?

谢谢!

4

1 回答 1

1

当您加载由 Snapshot Extension 生成的快照时,您需要从 trainer 中执行此操作。

chainer.serializers.load_npz(args.load, trainer)训练器会自动加载更新器、优化器和模型的状态。

您还可以通过访问快照中的相应字段并将其作为参数传递给model.serialize函数来手动仅加载模型

npz_data = numpy.load(args.load)
snap = chainer.serializers.NpzDeserializer(npz_data)
model.serialize(snap['updater']['model:main'])

这应该只加载模型的权重

于 2020-03-25T09:49:01.377 回答