
I am training a PPO agent with Ray RLlib and have made two modifications to PPOTFPolicy:

  • I added a mixin class (call it "Recal") to the "mixins" argument of "build_tf_policy()". PPOTFPolicy therefore inherits from my "Recal" class and can access the member functions defined in "Recal". My "Recal" class is a simple subclass of tf.keras.Model.
  • I defined a "my_postprocess_fn" function to replace "compute_gae_for_sample_batch" as the "postprocess_fn" argument of "build_tf_policy()".
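For background, the "mixins" mechanism amounts to ordinary multiple inheritance: the mixin classes become extra base classes of the generated policy class, so their methods and attributes are reachable on the policy object. Here is a toy, RLlib-free sketch of that composition; `BasePolicy`, `build_policy`, and the string placeholders are my own illustrations, not RLlib APIs:

```python
# Toy illustration (NOT RLlib's actual implementation) of how a `mixins`
# list typically becomes extra base classes of a dynamically built policy
# class, making mixin attributes visible via the policy instance.

class BasePolicy:
    def __init__(self):
        self.model = "dummy_model"  # stands in for the policy's TF model

def build_policy(name, mixins):
    # Compose the policy class from the mixins plus the base class.
    return type(name, tuple(mixins) + (BasePolicy,), {})

class Recal:  # stands in for the questioner's mixin
    def setup_recal(self):
        self.recal_model = "keras_model_placeholder"

MyPolicy = build_policy("MyPolicy", [Recal])
policy = MyPolicy()
policy.setup_recal()
print(policy.recal_model)  # the mixin's attribute is visible on the policy
```

Because the mixin is just another base class, anything it creates (e.g. a Keras model) lives on the same policy instance, which is why the questioner expects it to share the policy's TensorFlow graph.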

The "PPOTrainer = build_trainer(...)" call is unchanged. I use framework="tf" with eager execution disabled.

Pseudocode is shown below; a runnable version exists in Colab.

tf.compat.v1.disable_eager_execution()

class Recal:
    def __init__(self):
        self.recal_model = build_and_compile_keras_model()

def my_postprocess_fn(policy, sample_batch, other_agent_batches=None, episode=None):
    # Run the mixin's Keras model inside the policy's graph before computing GAE.
    with policy.model.graph.as_default():
        sample_batch = policy.recal_model.predict(sample_batch)
    return compute_gae_for_sample_batch(policy, sample_batch,
                                        other_agent_batches, episode)

PPOTFPolicy = build_tf_policy(..., postprocess_fn=my_postprocess_fn, mixins=[..., Recal])
PPOTrainer = build_trainer(...)
ppo_trainer = PPOTrainer(config=DEFAULT_CONFIG, env="CartPole-v0")

for i in range(1):
    result = ppo_trainer.train()

With this setup, "Recal" is a base class of PPOTFPolicy, so when an instance of PPOTFPolicy is created, "Recal" should be instantiated in the same TensorFlow graph. However, when my_postprocess_fn() is called, it raises the error below.

tensorflow.python.framework.errors_impl.FailedPreconditionError: Could not find variable default_policy_wk1/my_model/dense/kernel. This could mean that the variable has been deleted. In TF1, it can also mean the variable is uninitialized. Debug info: container=localhost, status error message=Container localhost does not exist. (Could not find resource: localhost/default_policy_wk1/my_model/dense/kernel)
     [[{{node default_policy_wk1/my_model_1/dense/MatMul/ReadVariableOp}}]]
