0

我在 ubuntu 18.04 上使用 torch==1.7.1

我想要做的是训练一个强化学习代理并将其发送到另一台服务器,因此训练有素的代理可以立即播放。

我一直在关注https://pytorch.org/tutorials/beginner/saving_loading_models.html但是需要澄清一下。

完成所有培训后,我通过以下方式保存了代理、优化器:

ckpt = {
    'Epoch': epoch,
    'model' : agent.state_dict(),
    'optimizer' : optimizer.state_dict()
}
torch.save(ckpt, filename)

现在,我已将此保存的文件发送到不同的服务器并加载它,如下所示:

ckpt = torch.load(ckpt_FILE_PATH)

agent = Agent()
optimizer = optim.Adam(agent.parameters(), lr=0.0005)

agent.load_state_dict(ckpt["agent"])
optimizer.load_state_dict(ckpt["optimizer"])

agent.eval()

我只想确定一件事。在训练阶段的代理实例和优化器创建期间,我使用实例化了优化器agent.parameters()

所以我的问题是我需要

  1. load_state_dict 到代理,然后使用 agent.parameters() 实例化优化器
  2. 只是在加载 load_state_dict 到代理之前实例化优化器?
  3. 没关系。

提前致谢。

4

0 回答 0