7

我有以下代码:

data = np.load("data.npy")
print(data) # Makes sure the array gets loaded in memory
dataset = tf.contrib.data.Dataset.from_tensor_slices((data))

该文件"data.npy"为 3.3 GB。使用 numpy 读取文件需要几秒钟,但创建 tensorflow 数据集对象的下一行需要很长时间才能执行。这是为什么?它在引擎盖下做什么?

4

2 回答 2

5

引用这个答案

np.loadof anpz只是返回一个文件加载器,而不是实际数据。它是一个“惰性加载器”,仅在访问时加载特定的数组。

这就是为什么它很快。

编辑 1:为了进一步扩展这个答案,tensorflow 文档的另一个引用:

如果您的所有输入数据都适合内存,那么从它们创建 a 的最简单方法Dataset是将它们转换为tf.Tensor对象并使用Dataset.from_tensor_slices().

这适用于小型数据集,但会浪费内存——因为数组的内容将被复制多次——并且可能会遇到 tf.GraphDef 协议缓冲区的 2GB 限制。

该链接还显示了如何有效地做到这一点。

于 2017-10-20T19:58:12.033 回答
0

尝试:

data = np.load("data.npy")
a = tf.placeholder(tf.float32, shape)
dataset = tf.data.Dataset.from_tensor_slices(a)
dataset = dataset.prefetch(buffer_size=1000)
dataset = dataset.batch(128)
iterator = dataset.make_initializable_iterator()
next_batch = iterator.get_next()
with tf.Session() as sess:
    sess.run(iterator.initializer, feed_dict={a: data})

处理大型数据集时,tf.placeholder效果更好。

于 2019-01-22T13:41:18.383 回答