1

Keras 的 ImageDataGenerator 看起来非常适合简单地逐步加载图像并将迭代器传递给 model.fit 函数。然而,它似乎只适用于图像和分类任务。

我想做回归,即我的标签也是与训练集形状相同的数组。在实践中,它们是像图像一样的多维(>1 个通道)数组,但它们不是图像。

关于使用什么类来简单地将批量数据发送到 keras model.fit() 以训练深度神经网络的任何建议?

当然,问题是我的数据集太大而无法放入内存,这就是我需要使用这些生成器/迭代器的原因。

4

2 回答 2

1

您的情况的最佳解决方案是使用tf.data.Dataset().

虽然可能需要相对较短的时间来适应它,但它是加载数据并使用 model.fit() 的推荐方法。

您可以在此处查阅文档:https ://www.tensorflow.org/api_docs/python/tf/data/Dataset

它是新的、快速的、设计精美且易于扩展的。

例如,对于您的问题,您可能想要使用tf.data.Dataset.from_tensor_slices(); 我会让你发现它的特点:D。

于 2020-04-10T08:23:17.517 回答
0

一个快速的解决方案是使用 Colab,其 GPU 实例有 24 GB RAM 可供使用。当你像我在这里做的那样加载 numpy 数组时,你也可以减少你的内存

于 2020-04-10T08:02:02.550 回答