3

我正在使用 tensorflow 的数据集 API。并用简单的案例测试我的代码。下面显示了我使用的简单代码。问题是,当数据集大小较小时,从数据集 API 返回的大小似乎不一致。我确信有一个适当的方法来处理它。但即使我阅读了该页面和教程中的所有功能,我也找不到。

import numpy as np
import tensorflow as tf

data_source = tf.zeros([24, 200, 64, 64, 1]) #[number_of_video, steps, pixel_w, pixel_h, channel]
dataset = tf.contrib.data.Dataset.from_tensor_slices(data_source)
dataset = dataset.shuffle(buffer_size=100)
dataset = dataset.batch(16)
dataset = dataset.repeat()

iterator = tf.contrib.data.Iterator.from_structure(dataset.output_types, dataset.output_shapes)
next_element = iterator.get_next()
training_init_op = iterator.make_initializer(dataset)

with tf.Session() as sess:
    sess.run(training_init_op)
    next_elem = next_element.eval()
    print(np.shape(next_elem))
    next_elem = next_element.eval()
    print(np.shape(next_elem))
    next_elem = next_element.eval()
    print(np.shape(next_elem))
    next_elem = next_element.eval()
    print(np.shape(next_elem))
    next_elem = next_element.eval()
    print(np.shape(next_elem))
    next_elem = next_element.eval()
    print(np.shape(next_elem))
    next_elem = next_element.eval()
    print(np.shape(next_elem))

数据集是灰度视频。共有 24 个视频序列,步长均为 200。帧大小为 64 x 64 和单通道。我将批量大小设置为 16,缓冲区大小设置为 100。但代码的结果是,

(16, 200, 64, 64, 1)
(8, 200, 64, 64, 1)
(16, 200, 64, 64, 1)
(8, 200, 64, 64, 1)
(16, 200, 64, 64, 1)
(8, 200, 64, 64, 1)
(16, 200, 64, 64, 1)

返回的视频大小是16或8。我猜是因为原始数据大小很小,24,当它到达数据末尾时,API只是返回剩下的。

但我不明白。我还将缓冲区大小设置为 100。这意味着应该提前用小数据集填充缓冲区。并且从该缓冲区中,API 应该选择批量大小为 16 的 next_element。

当我在 tensorflow 中使用队列类型的 API 时,我没有遇到这个问题。无论原始数据的大小是多少,总有一天迭代器会到达数据集的末尾。我想知道其他人如何使用这个 API 解决这个问题。

4

2 回答 2

6

尝试调用repeat()之前batch()

data_source = tf.zeros([24, 200, 64, 64, 1]) #[number_of_video, steps, pixel_w, pixel_h, channel]
dataset = tf.contrib.data.Dataset.from_tensor_slices(data_source)
dataset = dataset.shuffle(buffer_size=100)
dataset = dataset.repeat()
dataset = dataset.batch(16)

我得到的结果:

(16, 200, 64, 64, 1)
(16, 200, 64, 64, 1)
(16, 200, 64, 64, 1)
(16, 200, 64, 64, 1)
(16, 200, 64, 64, 1)
(16, 200, 64, 64, 1)
(16, 200, 64, 64, 1)
于 2017-10-12T13:18:51.357 回答
0

您可以使用下面的代码来解决问题:

batched = dataset.apply(tf.contrib.data.batch_and_drop_remainder(128))
于 2018-10-30T08:49:00.300 回答