0

按照教程https://www.tensorflow.org/agents/tutorials/7_SAC_minitaur_tutorial,我想知道如何加载创建的检查点。

我正在尝试按照本教程https://www.tensorflow.org/agents/tutorials/10_checkpointer_policysaver_tutorial加载检查点,但我现在绝对知道如何使其与 SAC-Agent 一起使用。

检查点通过以下方式保存:

# Triggers to save the agent's policy checkpoints.
learning_triggers = [
triggers.PolicySavedModelTrigger(
    saved_model_dir,
    tf_agent,
    train_step,
    interval=policy_save_interval),
     triggers.StepPerSecondLogTrigger(train_step, interval=1000),
]

我成功地加载了政策:

eager_py_policy = py_tf_eager_policy.SavedModelPyTFEagerPolicy(
policy_dir, env.time_step_spec(), env.action_spec())

因此可以创建一个执行策略的参与者:

time_step = env.reset()
initial_collect_actor = actor.Actor(
env,
eager_py_policy,
time_step,
steps_per_run=1,
)
initial_collect_actor.run()

不幸的是,我的模拟环境经常崩溃,所以我需要能够加载整个模型并从最后一个检查点继续训练。

4

0 回答 0