0

有了tensorflow,我做了一个dataset = tf.data.TFRecordDataset(filename)iterator = dataset.make_one_shot_iterator()。然后在每一轮中iterator.get_next()都会给出一小批数据作为输入。

我正在训练一个有Dropout层的网络,所以我应该写这样的东西:

sess.run(train_op,feed_dict={keep_prob:0.5})
accuracy,loss = sess.run([acc,loss],feed_dict={keep_prob:1.0})

其中keep_prob表示保持神经元存活的概率,这在训练和测试(这里是评估)过程中有所不同。

这里出现的问题是每个sess.run()触发器iterator.get_next()都会获得一批新的输入。这不是它应该的样子。

如果我想让这两个sess.run()具有相同的输入张量,我该怎么办?

非常感谢 :-)

4

1 回答 1

0

圣诞节快乐!

非常感谢圣诞老人的礼物 :-)

我刚刚被引导到这个地方,在那里你可以找到这个问题的答案。

主要思想是使用tf.data.Iterator.from_structure()而不是tf.data.Dataset.make_initializable_iterator()创建迭代器,并分别为训练、验证和测试数据集初始化迭代器。

圣诞快乐新年快乐!

于 2018-12-25T12:06:17.363 回答