0

我认为我有一个在 TensorFlow v1 上开发的项目。它在 Python 3.8 中像这样工作:

 ...
 saver = tf.train.Saver(var_list=vars)
 ...
 saver.restore(self.sess, tf.train.latest_checkpoint(checkpoint_dir))
 ...

检查点文件位于“checkpoint_dir”中

我想将它与 TFjs 一起使用,但我不知道如何将检查点文件转换为可以用 TFjs 加载的东西。

我应该怎么办?

谢谢,

约翰

4

1 回答 1

0

好的,我想通了。希望这对像我这样的其他初学者也有帮助。

检查点文件不包含模型,它们只包含模型的值(权重等)。

该模型实际上是在代码中构建的。因此,以下是将 Tensorflow v1 检查点文件转换为 TensorflowJS 可加载模型的步骤:

  1. 首先,我再次保存了检查点,因为缺少一个文件(.meta 文件),其中包含有关检查点中值的一些元信息。为了使用 meta 保存检查点,我在saver.restore(...调用后立即使用了此代码,如下所示:
...
saver.save(self.sess,save_path='./newcheckpoint/')
...
  1. 将模型另存为冻结模型文件,如下所示:
import tensorflow.compat.v1 as tf

meta_path = './newcheckpoint/.meta' # Your .meta file
output_node_names = ['name_of_the_output_node']    # Output nodes

with tf.Session() as sess:
    # Restore the graph
    saver = tf.train.import_meta_graph(meta_path)

    # Load weights
    saver.restore(sess,tf.train.latest_checkpoint('./newcheckpoint/'))

    # Freeze the graph
    frozen_graph_def = tf.graph_util.convert_variables_to_constants(
        sess,
        sess.graph_def,
        output_node_names)

    # Save the frozen graph
    with open('./freeze/output_graph.pb', 'wb') as f:
      f.write(frozen_graph_def.SerializeToString())

这会将模型保存到./freeze/output_graph.pb

  1. 使用tensorflowjs_converter将冻结模型转换为 Web 模型,如下所示:

tensorflowjs_converter --input_format=tf_frozen_model --output_node_names='final_add' --skip_op_check ./freeze/output_graph.pb ./web_model/

--skip_op_check由于尝试转换时缺少一些操作错误/警告而不得不使用。

作为第 3 步的结果,该./webmodel/文件夹将包含 TensorflowJS 库所需的 JSON 和二进制文件。

这是我使用 tfjs 2.x 加载模型的方式:

model=await tf.loadGraphModel('web_model/model.json');

于 2021-05-07T10:03:37.053 回答