17

例如,在我创建操作后,通过操作输入批处理数据并运行操作,tf.train.batch 是否会自动将另一批数据输入会话?

我问这个是因为 tf.train.batch 有一个属性,allow_smaller_final_batch该属性使得最终批次可以加载为小于指定批次大小的大小。这是否意味着即使没有循环,下一批也可以自动喂食?从教程代码中我很困惑。当我加载一个批次时,我实际上得到了一个形状 [batch_size, height, width, num_channels] 的批次大小,但文档说它Creates batches of tensors in tensors.另外,当我阅读tf-slim walkthrough tutorial中的教程代码时,其中有一个名为 load_batch 的函数,只返回 3 个张量:images, images_raw, labels. 文档中解释的“批次”数据在哪里?

感谢您的帮助。

4

3 回答 3

18

... tf.train.batch 是否会自动将另一批数据输入会话?

不会。没有任何事情会自动发生。您必须sess.run(...)再次调用以加载新批次。

这是否意味着即使没有循环,下一批也可以自动喂食?

tf.train.batch(..),将始终加载batch_size张量。例如,如果您有 100 个图像和一个,batch_size=30那么您将有 3*30 个批次,因为您可以sess.run(batch)在输入队列从头开始(或停止如果epoch=1)之前调用 3 次。这意味着您错过了100-3*30=10训练中的样本。如果您不想错过它们tf.train.batch(..., allow_smaller_final_batch=True),现在可以这样做,在输入队列重新启动之前,您将拥有 3x 30-sample-batch 和 1x 10-sample-batch。

让我也用一个代码示例来详细说明:

queue = tf.train.string_input_producer(filenames,
        num_epochs=1) # only iterate through all samples in dataset once

reader = tf.TFRecordReader() # or any reader you need
_, example = reader.read(queue)

image, label = your_conversion_fn(example)

# batch will now load up to 100 image-label-pairs on sess.run(...)
# most tf ops are tuned to work on batches
# this is faster and also gives better result on e.g. gradient calculation
batch = tf.train.batch([image, label], batch_size=100)

with tf.Session() as sess:
    # "boilerplate" code
    sess.run([
        tf.local_variables_initializer(),
        tf.global_variables_initializer(),
    ])
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    try:
        # in most cases coord.should_stop() will return True
        # when there are no more samples to read
        # if num_epochs=0 then it will run for ever
        while not coord.should_stop():
            # will start reading, working data from input queue
            # and "fetch" the results of the computation graph
            # into raw_images and raw_labels
            raw_images, raw_labels = sess.run([images, labels])
    finally:
        coord.request_stop()
        coord.join(threads)
于 2017-01-16T12:28:07.553 回答
0

每次要加载下一个批次时,您都需要调用 sess.run 并将批次传递给它。请参阅下面的代码。

img = [0,1,2,3,4,5,6,7,8]
lbl = [0,1,2,3,4,5,6,7,8]
images = tf.convert_to_tensor(img)
labels = tf.convert_to_tensor(lbl)
input_queue = tf.train.slice_input_producer([images,labels])
sliced_img = input_queue[0]
sliced_lbl = input_queue[1]

img_batch, lbl_batch = tf.train.batch([sliced_img,sliced_lbl], batch_size=3)
with tf.Session() as sess:
    coord   = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)

    for i in range(0,3): #batch size
        image_batch,label_batch = sess.run([img_batch,lbl_batch ])
        print(image_batch, label_batch)

    coord.request_stop()
    coord.join(threads)

答案是这样的:

[4,1,8] [4,1,8]

[2,3,7] [2,3,7]

[2,6,8] [2,6,8]

于 2019-02-06T16:18:29.803 回答
0

我对来自https://github.com/tensorflow/models/blob/master/research/slim/slim_walkthrough.ipynb的代码和来自上述帖子的 bodokaiser 答案进行了修改。请注意,这是来自https://github.com/tensorflow/models/tree/master/research/slim、eval_image_classifier.py上的评估脚本。对 eval_image_classifier.py 代码最重要的修改是将 num_epochs=1 添加到 DatasetDataProvider 行。这样,所有图像都将被访问一次以进行推理。

provider = slim.dataset_data_provider.DatasetDataProvider(
    dataset,
    shuffle=False,
    common_queue_capacity=2 * FLAGS.batch_size,
    common_queue_min=FLAGS.batch_size, num_epochs=1)
[image, label] = provider.get(['image', 'label'])
images, labels = tf.train.batch(
    [image, label],
    batch_size=FLAGS.batch_size,
    num_threads=FLAGS.num_preprocessing_threads,
    capacity=1 * FLAGS.batch_size)
with tf.Session() as sess:
     sess.run([tf.local_variables_initializer(),
               tf.global_variables_initializer(),])
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        while not coord.should_stop():
            np_image, np_label = sess.run([images, labels])
    except:
        coord.request_stop()
        coord.join(threads)
于 2021-01-01T04:13:15.237 回答