14

How do I output the values of a dataset multiple times? (The dataset is created with TensorFlow's Dataset API.)

import tensorflow as tf

dataset = tf.contrib.data.Dataset.range(100)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
sess = tf.Session()
epoch = 10

for i in range(epoch):
   for j in range(100):
      value = sess.run(next_element)
      assert j == value
      print(j)

Error message:

tensorflow.python.framework.errors_impl.OutOfRangeError: End of sequence
 [[Node: IteratorGetNext = IteratorGetNext[output_shapes=[[]], output_types=[DT_INT64], _device="/job:localhost/replica:0/task:0/cpu:0"](OneShotIterator)]]

How can I make this work?

4 Answers

25

First of all, I recommend reading the Datasets guide; it describes all the details of the Dataset API.

Your question is about iterating over the data multiple times. Here are two solutions:

  1. Iterate over all epochs at once, with no information about where each individual epoch ends:
import tensorflow as tf

epoch   = 10
dataset = tf.data.Dataset.range(100)
dataset = dataset.repeat(epoch)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
sess = tf.Session()

num_batch = 0
j = 0
while True:
    try:
        value = sess.run(next_element)
        assert j == value
        j += 1
        num_batch += 1
        if j > 99: # new epoch
            j = 0
    except tf.errors.OutOfRangeError:
        break

print ("Num Batch: ", num_batch)
  2. The second option notifies you about the end of each epoch, so you can, for example, check the validation loss there (see the sketch after the code below):
import tensorflow as tf

epoch = 10
dataset = tf.data.Dataset.range(100)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
sess = tf.Session()

num_batch = 0

for e in range(epoch):
    print ("Epoch: ", e)
    j = 0
    sess.run(iterator.initializer)
    while True:
        try:
            value = sess.run(next_element)
            assert j == value
            j += 1
            num_batch += 1
        except tf.errors.OutOfRangeError:
            break

print ("Num Batch: ", num_batch)
Answered 2017-11-10T06:51:30.753
3

If your tensorflow version is 1.3+, I recommend the high-level API tf.train.MonitoredTrainingSession. The sess created by this API can automatically detect tf.errors.OutOfRangeError via sess.should_stop(). For most training situations you also need to shuffle the data and fetch one batch at each step, which I have added in the code below.

import tensorflow as tf

epoch = 10
dataset = tf.data.Dataset.range(100)
dataset = dataset.shuffle(buffer_size=100) # comment this line if you don't want to shuffle data
dataset = dataset.batch(batch_size=32)     # batch_size=1 if you want to get only one element per step
dataset = dataset.repeat(epoch)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

num_batch = 0
with tf.train.MonitoredTrainingSession() as sess:
    while not sess.should_stop():
        value = sess.run(next_element)
        num_batch += 1
        print("Num Batch: ", num_batch)
Answered 2017-12-21T04:15:26.343
3

Try this:

while True:
  try:
    print(sess.run(next_element))  # next_element from the question's iterator
  except tf.errors.OutOfRangeError:
    break

Whenever the dataset iterator reaches the end of the data, it raises tf.errors.OutOfRangeError; you can catch it with an except clause and start the dataset over from the beginning.
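Note that the question's one-shot iterator cannot be rewound, so after the break the loop simply ends; to actually start the dataset from the beginning, an initializable iterator can be re-initialized, roughly like this (a sketch under that assumption, not part of the original answer):

import tensorflow as tf

dataset = tf.data.Dataset.range(100)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
    for _ in range(10):                    # 10 passes over the data
        sess.run(iterator.initializer)     # rewind to the start of the dataset
        while True:
            try:
                print(sess.run(next_element))
            except tf.errors.OutOfRangeError:
                break                      # end of this pass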

Answered 2018-03-27T18:57:58.660
2

Similar to Tom's answer, for tensorflow 2+ you can use the following high-level API calls (the code proposed in his answer is deprecated in tensorflow 2+):

import tensorflow as tf

epoch = 10
batch_size = 32
dataset = tf.data.Dataset.range(100)

dataset = dataset.shuffle(buffer_size=100) # comment this line if you don't want to shuffle data
dataset = dataset.batch(batch_size=batch_size)
dataset = dataset.repeat(epoch)

num_batch = 0
for batch in dataset:
    num_batch += 1
    print("Num Batch: ", num_batch)

A useful call for tracking progress is getting the total number of batches that will be iterated (use it after the batch and repeat calls):

num_batches = tf.data.experimental.cardinality(dataset)

Note that, as of tensorflow 2.1, the cardinality method is still experimental.
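A short usage sketch (my own, not from the answer) showing what cardinality reports for the pipeline above:

import tensorflow as tf

epoch = 10
batch_size = 32
dataset = tf.data.Dataset.range(100).batch(batch_size).repeat(epoch)

# 100 elements at batch_size 32 -> 4 batches per epoch (the last one partial),
# repeated for 10 epochs -> 40 batches in total
num_batches = tf.data.experimental.cardinality(dataset)
print(int(num_batches))  # 40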

Answered 2020-02-05T12:48:42.720