0

在 ray rllib 中,我通常会应用 ray.tune.run 进行这样的 ppo 训练:

ray.init(log_to_driver=False, num_cpus=3, 
    local_mode=args.local_mode, num_gpus=1)
env_config={"code":"codeA"}
config={
 env_config={
     "code":"codeA"},
 "parm":"paramA"}
stop = {
    "training_iteration": args.stop_iters,
    "timesteps_total": args.stop_timesteps,
    "episode_reward_mean": args.stop_reward,
}
results = tune.run(trainer, config=config1, verbose=0, 
  stop=stop1, checkpoint_at_end=True,                               
  metric='episode_reward_mean', mode="max", 
  checkpoint_freq=1
             )

  checkpoints = results.get_trial_checkpoints_paths(
    trial=results.get_best_trial(
    metric='episode_reward_mean', 
    mode="max"),metric='episode_reward_mean')
  checkpoint_path = checkpoints[0][0]
  metric = checkpoints[0][1]

在下一轮,我通常使用这样的恢复检查点方法重新训练模型:

 results = tune.run('PPO', config=config1, verbose=0, 
      stop=stop, checkpoint_at_end=True,                                   
      metric='episode_reward_mean', mode="max", checkpoint_freq=1, restore=checkpoint_path)

推断:

agent = ppo.PPOTrainer(config=config1, env=env)
agent.restore(checkpoint_path=checkpoint_path)

这些流程奏效了。问题是(1):我是否可以在 ray.tune.run 结束时保存整个 pytorch 模型?(2) 下一轮ray.tune.run训练除了checkpoints恢复以外,可以导入pytorch模型吗?(3) 在推理阶段,如何将训练好的整个 pytorch 模型导入 PPO 代理?在恢复代理推理流程中,我一次不能将超过 1o 个模型加载到计算机内存中。大负载显示OOM问题。如果我一个一个地恢复一个模型,检查点恢复过程太耗时,不能满足时效性要求。谁能帮我?

4

1 回答 1

0

您可以在 tune.run() 中查看 keep_checkpoints_num 和 checkpoints_score_attr 以从此处自定义您想要多少个检查点 keep_checkpoints_num 的默认值为 None 因此它将存储所有检查点,但对于存储限制,您可以根据检查点仅保留顶部的检查点分数属性

于 2022-01-05T19:42:08.393 回答