0

在来自 stable baselines3 网站( https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html )的示例代码中,模型首先会通过model.learn(total_timesteps=25000)line 学习,然后可以在播放循环中使用.

现在,由于我希望能够在代理学习过程中监控不同的参数(来自自定义环境),所以我的问题是:如何model.learn在播放循环中使用?

import gym
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env

# Parallel environments
env = make_vec_env("CartPole-v1", n_envs=4)

model = PPO("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=25000)
model.save("ppo_cartpole")

del model # remove to demonstrate saving and loading

model = PPO.load("ppo_cartpole")

obs = env.reset()
while True:
    action, _states = model.predict(obs)
    obs, rewards, dones, info = env.step(action)
    env.render()
4

1 回答 1

0

用于训练的播放循环包含许多特定算法(例如 PPO)所需的各种操作。这种播放循环称为Rolloutscollect_rollouts您可以在 中找到推出功能stable_baselines3.common.on_policy_algorithm.OnPolicyAlgorithm。因此,如果在框架中为您完成,最好不要编写自己的训练循环。

要跟踪各种参数(包括自定义参数),您可以查看回调https://stable-baselines3.readthedocs.io/en/master/guide/callbacks.html)。这可以包含在model.learn(timestamps=25000, callback=custom_callback). 此外,如果您只是想玩学习模型,您可以使用评估函数而不是使用相同的回调进行学习,以跟踪参数:

from stable_baselines3.common.evaluation import evaluate_policy
...
model = PPO("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=25000)
evaluate_policy(model.policy, env, n_eval_episodes=10, deterministic=True, callback=custom_callback)
于 2021-05-22T23:32:17.540 回答