3

在使用度量操作(例如来自 tf.python.ops.metrics 的准确度)训练图表后,我尝试恢复图表并评估测试集的准确度。但是,在使用 tf.import_meta_graph 恢复图形后,当我尝试使用 tf.local_variables_initializer() 初始化局部变量(这是必要的)时,出现错误,它说 'Tensor' 对象没有属性 'initializer'。

如果我在恢复后打印局部变量,有两个 Tensorflow 张量可能会导致问题。这两个 tensorlow 张量源于准确度指标:

  <tf.Tensor 'accuracy/total:0' shape=() dtype=float32_ref>
  <tf.Tensor 'accuracy/count:0' shape=() dtype=float32_ref>

有人可以帮我弄这个吗?谢谢!

类似代码:

def train():
  l_ini = np.array([1, 0, 1, 0, 1, 0], dtype=np.float32)
  p_ini = np.array([1, 0, 1, 0, 1, 1], dtype=np.float32)
  l = tf.Variable(l_ini, trainable=False)
  p = tf.Variable(p_ini, trainable=False)
  accuracy = metrics.accuracy(labels=l, predictions=p)
  tf.add_to_collection("accuracy", accuracy)

  graph = tf.get_default_graph()

  sess = tf.Session(graph=graph)
  sess.run(tf.global_variables_initializer())
  sess.run(tf.local_variables_initializer())
  acc = sess.run(accuracy)

  saver = tf.train.Saver()
  saver.save(sess, 'test.ckpt')

def restore():
  with tf.Session() as sess:
      loader = tf.train.import_meta_graph('./test.ckpt.meta')
      loader.restore(sess, './test.ckpt')
      accuracy = tf.get_collection("accuracy")

      sess.run(tf.global_variables_initializer())
      sess.run(tf.local_variables_initializer())
      acc = sess.run(accuracy)
4

1 回答 1

1

我有一个解决方法,而不是检索准确性集合(get_collection在我的情况下返回一个空列表):

  • 检索 logits 和 label 占位符。
  • 然后计算准确率。
  • 记得在恢复会话后初始化本地运行变量: self.running_vars = tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES, scope="your_accuracy_scope_name")
于 2018-10-17T23:47:20.737 回答