I am training a PPO agent with Ray RLlib and have made two modifications to `PPOTFPolicy`:
- I added a mixin class (call it `Recal`) to the `mixins` argument of `build_tf_policy()`. This makes `PPOTFPolicy` inherit from my `Recal` class, so the policy can access the member functions defined in `Recal`. My `Recal` class is a simple subclass of `tf.keras.Model`.
- I defined a function `my_postprocess_fn` and passed it as the `postprocess_fn` argument of `build_tf_policy()`, replacing the default `compute_gae_for_sample_batch`.
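For context, `build_tf_policy`'s `mixins` mechanism amounts to ordinary multiple inheritance on a dynamically built class. Below is a minimal plain-Python sketch of that composition; the names (`BasePolicy`, `build_policy`) are illustrative stand-ins, not RLlib's actual internals:

```python
# Illustrative sketch: how mixin classes end up as base classes of a
# dynamically built policy class. Not RLlib's real implementation.

class BasePolicy:
    def __init__(self):
        self.model = "policy-model"

class Recal:
    """Mixin holding an auxiliary recalibration model."""
    def __init__(self):
        self.recal_model = "recal-model"

def build_policy(name, mixins):
    # The generated class inherits from every mixin plus the base policy,
    # so mixin attributes and methods are reachable through `self`.
    bases = tuple(mixins) + (BasePolicy,)

    def __init__(self):
        for base in bases:
            base.__init__(self)

    return type(name, bases, {"__init__": __init__})

MyPolicy = build_policy("MyPolicy", mixins=[Recal])
policy = MyPolicy()
print(policy.recal_model)  # the mixin's attribute is available on the policy
```

This is why, in my setup, `policy.recal_model` is accessible from inside `my_postprocess_fn`.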
The `PPOTrainer = build_trainer(...)` call is unchanged. I use `framework="tf"` with eager execution disabled.
The pseudocode is below; this is the version that runs on Colab.
import tensorflow as tf
from ray.rllib.evaluation.postprocessing import compute_gae_for_sample_batch
from ray.rllib.policy.tf_policy_template import build_tf_policy
from ray.rllib.agents.trainer_template import build_trainer
from ray.rllib.agents.ppo.ppo import DEFAULT_CONFIG

tf.compat.v1.disable_eager_execution()

class Recal:
    def __init__(self):
        # Keras model used to recalibrate sample batches (defined elsewhere).
        self.recal_model = build_and_compile_keras_model()

def my_postprocess_fn(policy, sample_batch, other_agent_batches=None, episode=None):
    # Run the recalibration model inside the policy's graph, then compute GAE.
    with policy.model.graph.as_default():
        sample_batch = policy.recal_model.predict(sample_batch)
    return compute_gae_for_sample_batch(policy, sample_batch, other_agent_batches, episode)

PPOTFPolicy = build_tf_policy(..., postprocess_fn=my_postprocess_fn, mixins=[..., Recal])
PPOTrainer = build_trainer(...)

ppo_trainer = PPOTrainer(config=DEFAULT_CONFIG, env="CartPole-v0")
for i in range(1):
    result = ppo_trainer.train()
With this setup, `Recal` is a base class of `PPOTFPolicy`, so when an instance of `PPOTFPolicy` is created, `Recal` is instantiated in the same TensorFlow graph. But when `my_postprocess_fn()` is called, it raises the error below.
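For what it's worth, the error message looks like the classic TF1 symptom of reading a resource variable in a session that never ran its initializer (each session has its own "localhost" variable container). A minimal standalone reproduction of that failure class, independent of RLlib and assuming TF2's `compat.v1` graph mode:

```python
import tensorflow as tf

tf.compat.v1.disable_eager_execution()

g = tf.Graph()
with g.as_default():
    v = tf.compat.v1.get_variable(
        "v", shape=(), initializer=tf.compat.v1.ones_initializer())
    read = v.read_value()

with tf.compat.v1.Session(graph=g) as sess:
    try:
        # The variable exists in the graph, but this session's container
        # ("localhost") has no value for it yet.
        sess.run(read)
        failed = False
    except tf.errors.FailedPreconditionError:
        failed = True  # same error class as in the question
    sess.run(v.initializer)  # initialize in *this* session's container
    value = sess.run(read)   # now succeeds
```

So my guess is that `recal_model`'s variables were initialized in one session/container but `predict` is executing against another; whether that is what RLlib is doing here is exactly my question.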
tensorflow.python.framework.errors_impl.FailedPreconditionError: Could not find variable default_policy_wk1/my_model/dense/kernel. This could mean that the variable has been deleted. In TF1, it can also mean the variable is uninitialized. Debug info: container=localhost, status error message=Container localhost does not exist. (Could not find resource: localhost/default_policy_wk1/my_model/dense/kernel)
[[{{node default_policy_wk1/my_model_1/dense/MatMul/ReadVariableOp}}]]