0

我想在每个 episode 结束后获取该 episode 的数据。我在文档中看到可以使用 `stable_baselines3.common.monitor.ResultsWriter`,但不知道如何在我的代码中使用它。

import gym
import numpy as np
import Neural_Traffic_Env

import stable_baselines3
from stable_baselines3.common.callbacks import EvalCallback, CheckpointCallback, CallbackList, StopTrainingOnMaxEpisodes, EveryNTimesteps
from stable_baselines3 import DDPG
from stable_baselines3.common.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise
from stable_baselines3.common.monitor import Monitor, ResultsWriter

# Training environment. The Monitor wrapper records every finished episode's
# reward, length and elapsed time, and writes one row per episode (through
# ResultsWriter) to "Monitor.monitor.csv" -- so ResultsWriter does not need to
# be used directly to get per-episode data.
env = gym.make('NeuralTraffic-v1')
env = Monitor(env, filename="Monitor")

# EvalCallback needs its own environment: evaluating on the training env would
# reset it in the middle of a training rollout and would also mix evaluation
# episodes into the training Monitor log. No filename -> no second CSV file.
eval_env = Monitor(gym.make('NeuralTraffic-v1'))

eval_callback = EvalCallback(eval_env, best_model_save_path='./logs/best_model', log_path='./logs/', eval_freq=500)
checkpoint_callback = CheckpointCallback(save_freq=100, save_path='./saves/')
callback_max_episodes = StopTrainingOnMaxEpisodes(max_episodes=1000, verbose=1)
callback = CallbackList([callback_max_episodes, checkpoint_callback, eval_callback])

n_actions = env.action_space.shape[-1]
# Gaussian exploration noise: zero mean, std 0.1 on every action dimension.
action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions))

model = DDPG("MlpPolicy", env, action_noise=action_noise, verbose=1)
# total_timesteps is a step count, so pass an int rather than the float 1e6.
model.learn(total_timesteps=int(1e6), log_interval=1, callback=callback)
model.save("ddpg")

# Per-episode statistics gathered by the training Monitor, one entry per
# completed training episode (evaluation episodes are not included because
# they run on eval_env).
episode_rewards = env.get_episode_rewards()
episode_lengths = env.get_episode_lengths()
episode_times = env.get_episode_times()
for episode_idx, (reward, length, seconds) in enumerate(
    zip(episode_rewards, episode_lengths, episode_times), start=1
):
    print(f"episode {episode_idx}: reward={reward:.3f}, length={length}, time={seconds:.1f}s")

env = model.get_env()

另外,是否有 Stable Baselines 的论坛可以让我提出这类问题?

4

0 回答 0