在使用度量操作(例如来自 tf.python.ops.metrics 的准确度)训练图表后,我尝试恢复图表并评估测试集的准确度。但是,在使用 tf.import_meta_graph 恢复图形后,当我尝试使用 tf.local_variables_initializer() 初始化局部变量(这是必要的)时,出现错误,它说 'Tensor' 对象没有属性 'initializer'。
如果我在恢复后打印局部变量,有两个 Tensorflow 张量可能会导致问题。这两个 tensorlow 张量源于准确度指标:
<tf.Tensor 'accuracy/total:0' shape=() dtype=float32_ref>
<tf.Tensor 'accuracy/count:0' shape=() dtype=float32_ref>
有人可以帮我弄这个吗?谢谢!
类似代码:
def train():
l_ini = np.array([1, 0, 1, 0, 1, 0], dtype=np.float32)
p_ini = np.array([1, 0, 1, 0, 1, 1], dtype=np.float32)
l = tf.Variable(l_ini, trainable=False)
p = tf.Variable(p_ini, trainable=False)
accuracy = metrics.accuracy(labels=l, predictions=p)
tf.add_to_collection("accuracy", accuracy)
graph = tf.get_default_graph()
sess = tf.Session(graph=graph)
sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer())
acc = sess.run(accuracy)
saver = tf.train.Saver()
saver.save(sess, 'test.ckpt')
def restore():
with tf.Session() as sess:
loader = tf.train.import_meta_graph('./test.ckpt.meta')
loader.restore(sess, './test.ckpt')
accuracy = tf.get_collection("accuracy")
sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer())
acc = sess.run(accuracy)