1

假设你创建一个 tensorflow 会话,用权重训练一个网络,然后保存图形:

import tensorflow as tf
with tf.Session(graph=tf.Graph()).as_default() as sess:
    with sess.graph.as_default():
        some_var = tf.get_variable(name='foo', shape=(4))
        x = some_var + 1.0
        some_var.load([1, 2, 3, 4])
        # My understanding is that this saves the weights only:
        tf.train.Saver().save(sess, 'my/save/path')
        # My understanding is that this saves the graph structure (not sure if it saves the weights as well):
        graph_def = sess.graph.as_graph_def()

然后你想把它加载到的会话和一个的图表中。有几个原因需要这样做,例如,如果旧会话有大量后续图表添加,并且您想清除它以节省内存。另一个原因是,如果您的学习方法动态生成网络拓扑以最适合训练数据,那么不同的数据集将具有不同的结构。在这种情况下,很难简单地重新运行网络生成代码并加载一组权重。

with tf.Session(graph=tf.Graph()).as_default() as sess:
    with sess.graph.as_default():
        tf.import_graph_def(graph_def, name='')
        tf.train.Saver().restore(sess, 'my/save/path')  # error here

但是,当您尝试将其加载回来时,此代码会失败(尽管它抱怨正在保存):

ValueError: No variables to save

如何从 graph_def 和/或 tf.train.Saver() 文件加载新会话?

4

0 回答 0