3

我想使用一个未知作者提供的预训练张量流模型。我不知道他/她是如何设法将 tensorflow 模型(他/她使用 tensorflow 版本 >= 1.2)保存到一个扩展名为“.model”的文件中因为通常我得到三个文件“.meta”、“ .data'、'.index' 或一个带有 '.ckpt' 的文件。

如何恢复这个预训练模型?以后如何将模型保存为这种格式?

谢谢。

4

1 回答 1

2

我还在许多平台上问过这个问题,但还没有任何帮助。所以我决定做一些实验工作,这就是我发现的。这可能很长,但请多多包涵。

要在 Tensor-flow 中导入模型,我们使用

with tf.Session() as sess:
  new_saver = tf.train.import_meta_graph('my_test_model-1000.meta')
  new_saver.restore(sess, tf.train.latest_checkpoint('./'))

.meta文件包含训练模型的所有变量、操作、集合等。什么tf.train.latest_checkpoint('./')是使用检查点文件(它只是记录保存的最新检查点文件)来导入xxxx_model.data-00000-of-00001. 这.data-00000-of-00001包含所有权重、偏差、梯度等,必须加载到my_test_model-1000.meta.

总结【半完整代码】

with tf.Session() as sess:
    new_saver = tf.train.import_meta_graph('my_test_model-1000.meta')
    #new_saver.restore(sess, tf.train.latest_checkpoint('./'))
    tensor_variable = tf.trainable_variables()
    for tensor_var in tensor_variable:
        #print(sess.run(tensor_var))
        print(tensor_var)

这个初始代码将打印出所有.meta可训练的变量。如果你尝试运行print(sess.run(tensor_var)),你会得到一个错误。这是因为,变量尚未初始化。但是,如果您取消注释new_saver.restore(sess, tf.train.latest_checkpoint('./'))并运行print(sess.run(tensor_var)),您将获得所有变量以及加载到变量中的值。


现在到“.model”</h2>

我最好的猜测是它的xxxxxx.model工作原理很像xxxx_model.data-00000-of-00001tensorflow。它不包含变量,因此如果您尝试这样做

with tf.Session() as sess:
          new_saver = tf.train.import_meta_graph('xxx.model')

你会得到一个错误。请记住,原因是,该.model文件不包含任何变量或任何形式的操作图。如果你也尝试做

with tf.Session() as sess:
          new_saver = tf.train.Saver()
          new_saver.restore(sess, "xxxx.model")

你同样会得到一个错误。这是因为,没有相应的变量可以加载值。因此,如果您曾经获得一个xxx.model文件,您将不得不在尝试运行之前经历复制所有变量和操作的痛苦new_saver.restore(sess, "xxxx.model")。如果您能够复制该架构,那么这将毫无问题地顺利运行,希望如此。

很抱歉这很长,但考虑到互联网上几乎没有答案,我不得不把它做一个演讲。:)

于 2018-04-19T04:36:56.917 回答