有了tensorflow
,我做了一个dataset = tf.data.TFRecordDataset(filename)
和iterator = dataset.make_one_shot_iterator()
。然后在每一轮中iterator.get_next()
都会给出一小批数据作为输入。
我正在训练一个有Dropout
层的网络,所以我应该写这样的东西:
sess.run(train_op,feed_dict={keep_prob:0.5})
accuracy,loss = sess.run([acc,loss],feed_dict={keep_prob:1.0})
其中keep_prob
表示保持神经元存活的概率,这在训练和测试(这里是评估)过程中有所不同。
这里出现的问题是每个sess.run()
触发器iterator.get_next()
都会获得一批新的输入。这不是它应该的样子。
如果我想让这两个sess.run()
具有相同的输入张量,我该怎么办?
非常感谢 :-)