I have a custom environment whose observation space is a dict. It returns observations like this:
OrderedDict([('achieved_goal', array([ 0.4008276 , -0.0685866 , -0.22774519, 0.05827878, 0.47759697,
0.7327185 , 2.4765387 , -0.8607227 , 0.89627784, -0.3062557 ,
-0.60894597, -1.4110374 ], dtype=float32)), ('desired_goal', array([-1.005679 , 0.34147817, 0.9540531 , 1.1987132 , 0.37403303,
0.32209057, 0.31095287, -2.1119647 , 0.82215786, -0.6675792 ,
-1.5640837 , 0.7348459 ], dtype=float32)), ('observation', array([-0.39490733, -0.67843455, -0.43765455, 0.1409685 , -0.67161006,
1.3106273 , 0.04009145, -1.714885 , -1.7085567 , -0.44895488,
-0.6111999 , -1.9730839 , 0.93647414, 0.2714189 , -0.67204314,
0.8948596 , -0.14034131, 1.0312599 , -1.2369561 , -0.2345652 ,
-0.17095046, 0.36576194, 0.9939435 , -1.0381949 , -1.2953175 ,
1.4120669 , -0.23294891, 0.30627772, -1.2250876 , -0.35871807,
1.3074456 , -1.060916 , -2.451866 , 0.18679707, 0.609564 ,
-0.16821782, -0.8448521 , -1.0025802 , 0.6878543 , -2.1562986 ,
0.6426088 , 1.386251 , 1.0454125 , -2.2426984 ], dtype=float32))])
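However, algorithms like PPO cannot use dict spaces. When I try to filter the observation space down to the 'observation' key, I get the error shown further down.
How I filter: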
env.observation_space = env.observation_space['observation']
Error traceback:
Traceback (most recent call last):
  File "PPO.py", line 69, in <module>
    model.learn(total_timesteps=25000)
  File "/home/yb1025/.conda/envs/allegro_gym/lib/python3.6/site-packages/stable_baselines3/ppo/ppo.py", line 289, in learn
    reset_num_timesteps=reset_num_timesteps,
  File "/home/yb1025/.conda/envs/allegro_gym/lib/python3.6/site-packages/stable_baselines3/common/on_policy_algorithm.py", line 220, in learn
    total_timesteps, eval_env, callback, eval_freq, n_eval_episodes, eval_log_path, reset_num_timesteps, tb_log_name
  File "/home/yb1025/.conda/envs/allegro_gym/lib/python3.6/site-packages/stable_baselines3/common/base_class.py", line 379, in _setup_learn
    self._last_obs = self.env.reset()
  File "/home/yb1025/.conda/envs/allegro_gym/lib/python3.6/site-packages/stable_baselines3/common/vec_env/dummy_vec_env.py", line 62, in reset
    self._save_obs(env_idx, obs)
  File "/home/yb1025/.conda/envs/allegro_gym/lib/python3.6/site-packages/stable_baselines3/common/vec_env/dummy_vec_env.py", line 92, in _save_obs
    self.buf_obs[key][env_idx] = obs
TypeError: float() argument must be a string or a number, not 'dict'
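
For context, the traceback suggests that reassigning env.observation_space only changes the declared space: reset() still returns the full dict, so the vectorized env ends up trying to write a dict into a float buffer. Below is a minimal sketch of filtering the returned observations as well. KeyFilterWrapper and _ToyDictEnv are hypothetical names made up for illustration (neither is part of gym or stable_baselines3), the spaces in the toy env are placeholder strings, and the sketch assumes the older gym API in which reset() returns obs and step() returns (obs, reward, done, info).

```python
from collections import OrderedDict


class KeyFilterWrapper:
    """Expose one key of a dict-observation env as a flat observation.

    Hypothetical helper for illustration; it only covers reset() and step().
    """

    def __init__(self, env, key="observation"):
        self.env = env
        self.key = key
        # Narrow the declared space, as the assignment in the question does ...
        self.observation_space = env.observation_space[key]
        self.action_space = env.action_space

    def reset(self):
        # ... and also narrow what reset() actually returns.
        return self.env.reset()[self.key]

    def step(self, action):
        # Assumes the 4-tuple step API (obs, reward, done, info).
        obs, reward, done, info = self.env.step(action)
        return obs[self.key], reward, done, info


class _ToyDictEnv:
    """Stand-in for the custom env; the spaces are placeholder strings."""

    observation_space = {
        "achieved_goal": "placeholder space",
        "desired_goal": "placeholder space",
        "observation": "placeholder space",
    }
    action_space = "placeholder space"

    def reset(self):
        return OrderedDict(
            [
                ("achieved_goal", [0.4, -0.07]),
                ("desired_goal", [-1.0, 0.34]),
                ("observation", [-0.39, -0.68, -0.44]),
            ]
        )

    def step(self, action):
        return self.reset(), 0.0, False, {}


env = KeyFilterWrapper(_ToyDictEnv())
first_obs = env.reset()  # now a flat list, no longer a dict
next_obs, reward, done, info = env.step(action=None)
```

With gym installed, the same idea is usually written by subclassing gym.ObservationWrapper and overriding its observation() method, which keeps the wrapped env a proper gym environment when it is passed to PPO. Note that dropping 'desired_goal' also hides the goal from the policy, so this only makes sense if the goal is not needed as input.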