1

您好,我正在使用稳定的基线包(https://stable-baselines.readthedocs.io/),特别是我正在使用 PPO2 并且我不确定如何正确保存我的模型......我训练了 6 个虚拟天和让我的平均回报达到 300 左右,然后我认为这对我来说还不够,所以我又训练了 6 天的模型。但是当我查看训练统计数据时,每集的第二次训练返回开始于 30 左右。这表明它没有保存所有参数。

这就是我保存使用包的方式:

def make_env_init(env_id, rank, seed=0):
    """
    Utility function for multiprocessed env.

    :param env_id: (str) the environment ID
    :param seed: (int) the inital seed for RNG
    :param rank: (int) index of the subprocess
    """

    def env_init():
        # Important: use a different seed for each environment
        env = gym.make(env_id, connection=blt.DIRECT)
        env.seed(seed + rank)
        return env

    set_global_seeds(seed)
    return env_init



envs = VecNormalize(SubprocVecEnv([make_env_init(f'envs:{env_name}', i) for i in range(processes)]), norm_reward=False)

if os.path.exists(folder / 'model_dump.zip'):
    model = PPO2.load(folder / 'model_dump.zip', envs, **ppo_kwards)
else:
    model = PPO2(MlpPolicy, envs, **ppo_kwards)

model.learn(total_timesteps=total_timesteps, callback=callback)
model.save(folder / 'model_dump.zip')

4

1 回答 1

1

您保存模型的方式是正确的。训练不是一个单调的过程:在进一步训练后,它也可以显示出更糟糕的结果。

你可以做的,首先是写进度日志:

model = PPO2(MlpPolicy, envs, tensorboard_log="./logs/progress_tensorboard/")

要查看日志,请在终端中运行:

tensorboard --port 6004 --logdir ./logs/progress_tensorboard/

它会给你一个板的链接,然后你可以在浏览器中打开它(例如http://pc0259:6004/

其次,您可以在每个 X 步中对模型进行快照:

from stable_baselines.common.callbacks import CheckpointCallback

checkpoint_callback = CheckpointCallback(save_freq=1e4, save_path='./model_checkpoints/')
model.learn(total_timesteps=total_timesteps, callback=[callback, checkpoint_callback])

将其与日志结合,您可以选择表现最佳的模型!

于 2020-04-06T18:59:07.790 回答