我想使用一个未知作者提供的预训练张量流模型。我不知道他/她是如何设法将 tensorflow 模型(他/她使用 tensorflow 版本 >= 1.2)保存到一个扩展名为“.model”的文件中,因为通常我得到三个文件“.meta”、“ .data'、'.index' 或一个带有 '.ckpt' 的文件。
如何恢复这个预训练模型?以后如何将模型保存为这种格式?
谢谢。
我想使用一个未知作者提供的预训练张量流模型。我不知道他/她是如何设法将 tensorflow 模型(他/她使用 tensorflow 版本 >= 1.2)保存到一个扩展名为“.model”的文件中,因为通常我得到三个文件“.meta”、“ .data'、'.index' 或一个带有 '.ckpt' 的文件。
如何恢复这个预训练模型?以后如何将模型保存为这种格式?
谢谢。
我还在许多平台上问过这个问题,但还没有任何帮助。所以我决定做一些实验工作,这就是我发现的。这可能很长,但请多多包涵。
要在 Tensor-flow 中导入模型,我们使用
with tf.Session() as sess:
new_saver = tf.train.import_meta_graph('my_test_model-1000.meta')
new_saver.restore(sess, tf.train.latest_checkpoint('./'))
该.meta
文件包含训练模型的所有变量、操作、集合等。什么tf.train.latest_checkpoint('./')
是使用检查点文件(它只是记录保存的最新检查点文件)来导入xxxx_model.data-00000-of-00001
. 这.data-00000-of-00001
包含所有权重、偏差、梯度等,必须加载到my_test_model-1000.meta
.
with tf.Session() as sess:
new_saver = tf.train.import_meta_graph('my_test_model-1000.meta')
#new_saver.restore(sess, tf.train.latest_checkpoint('./'))
tensor_variable = tf.trainable_variables()
for tensor_var in tensor_variable:
#print(sess.run(tensor_var))
print(tensor_var)
这个初始代码将打印出所有.meta
可训练的变量。如果你尝试运行print(sess.run(tensor_var))
,你会得到一个错误。这是因为,变量尚未初始化。但是,如果您取消注释new_saver.restore(sess, tf.train.latest_checkpoint('./'))
并运行print(sess.run(tensor_var))
,您将获得所有变量以及加载到变量中的值。
我最好的猜测是它的xxxxxx.model
工作原理很像xxxx_model.data-00000-of-00001
tensorflow。它不包含变量,因此如果您尝试这样做
with tf.Session() as sess:
new_saver = tf.train.import_meta_graph('xxx.model')
你会得到一个错误。请记住,原因是,该.model
文件不包含任何变量或任何形式的操作图。如果你也尝试做
with tf.Session() as sess:
new_saver = tf.train.Saver()
new_saver.restore(sess, "xxxx.model")
你同样会得到一个错误。这是因为,没有相应的变量可以加载值。因此,如果您曾经获得一个xxx.model
文件,您将不得不在尝试运行之前经历复制所有变量和操作的痛苦new_saver.restore(sess, "xxxx.model")
。如果您能够复制该架构,那么这将毫无问题地顺利运行,希望如此。
很抱歉这很长,但考虑到互联网上几乎没有答案,我不得不把它做一个演讲。:)