1

我正在尝试将底层 tensorflow 模型从stable-baselines转换为tensorflowjs,以便能够在浏览器上使用该模型。但我无法进行转换

我按照这个 github 问题使用代码创建了必要的 tensorflow 文件:

def generate_checkpoint_from_model(model, checkpoint_name):  
        tf.saved_model.simple_save(model.sess, checkpoint_name, inputs={"obs": model.act_model.obs_ph}, outputs={"action": model.action_ph})

然后我尝试使用tensorflowjs_converter转换模型

tensorflowjs_converter --input_format=tf_saved_model test/ web_test

但是,它给了我以下错误:

Unable to lift tensor <tf.Tensor 'loss/action_ph:0' shape=(?,) dtype=int32> because it depends transitively on placeholder <tf.Operation 'loss/action_ph' type=Placeholder> via at least one path, e.g.: loss/action_ph (Placeholder)

我创建了以下带有错误的colab 笔记本,因此您可以尝试一下。

有谁知道如何使这种转换工作?

感谢您的帮助

4

1 回答 1

1

我将这个问题作为一个问题发布在 stable-baselines 中,他们回答了。我将在这里复制以供其他人参考:

您正在尝试保存 PPO 训练中使用的操作占位符(PPO 代理的一部分),但为了推断,您只需要经过训练的策略及其占位符 ( model.act_model)。通过将调用更改为此,colab 上的代码可以正常运行simple_save

tf.saved_model.simple_save(model.sess, checkpoint_name, inputs={"obs": model.act_model.obs_ph}, outputs={"action": model.act_model._policy_proba})

_policy_proba 的值取决于环境/算法。

于 2019-09-15T21:26:44.317 回答