0

我的数据在一个tfrecords文件中。tf.data.Dataset这个简单的代码使用api迭代和批处理图像。然而,每 100 个批次的计算时间增加了。为什么会这样以及如何解决这个问题?

import tensorflow as tf
import time
sess = tf.Session()
dataset = tf.data.TFRecordDataset('/tmp/data/train.tfrecords')
dataset = dataset.repeat()
dataset = dataset.batch(3)
iterator = dataset.make_one_shot_iterator()

prev_step = time.time()
for step in range(10000):
    tensors = iterator.get_next()
    fetches = sess.run(tensors)
    if step % 200 == 0:
        print("Step %6i time since last %7.5f" % (step, time.time() - prev_step))
        prev_step = time.time()

这将输出以下时间:

Step      0 time since last 0.01432
Step    200 time since last 1.85303
Step    400 time since last 2.15448
Step    600 time since last 2.65473
Step    800 time since last 3.15646
Step   1000 time since last 3.72434
Step   1200 time since last 4.34447
Step   1400 time since last 5.11210
Step   1600 time since last 5.87102
Step   1800 time since last 6.61459
Step   2000 time since last 7.57238
Step   2200 time since last 8.33060
Step   2400 time since last 9.37795      

tfrecords 文件包含 MNIST 图像,使用来自 Tensorflow 文档的 HowTo编写

为了缩小问题范围,我复制了从磁盘读取原始图像的代码。在这种情况下,每 200 个批次的时间按预期保持不变。

现在我的问题是:

  • 代码的哪一部分增加了计算时间?
  • 我应该将此作为 TensorFlow github 中的错误提交吗?

解决了!

回答我自己的问题:移出get_next()循环

4

1 回答 1

3

已解决:移出get_next()循环

于 2017-11-27T10:50:46.573 回答