1

我正在 Google Colab 上使用 Keras 进行图像分类。我使用 tf.keras.preprocessing.image_dataset_from_directory() 函数(https://www.tensorflow.org/api_docs/python/tf/keras/preprocessing/image_dataset_from_directory)加载图像,该函数返回一个 tf.data.Dataset 对象:

train_ds = tf.keras.preprocessing.image_dataset_from_directory(
  data_dir,
  validation_split=0.2,
  subset="training",
  seed=1234,
  image_size=(img_height, img_width),
  batch_size=batch_size,
  label_mode="categorical")

我发现当数据包含数千张图像时,model.fit() 将在训练多个批次后使用所有内存(我使用的是 Google Colab,并且可以看到 RAM 使用量在第一个 epoch 期间增长)。然后我尝试使用 Keras 序列,这是将部分数据加载到 RAM 中的建议解决方案(https://www.tensorflow.org/api_docs/python/tf/keras/utils/Sequence):

  class DatasetGenerator(tf.keras.utils.Sequence):
      def __init__(self, dataset):
          self.dataset = dataset

      def __len__(self):
          return tf.data.experimental.cardinality(self.dataset).numpy()

      def __getitem__(self, idx):
          return list(self.dataset.as_numpy_iterator())[idx]

我用以下方法训练模型:

history = model.fit(DatasetGenerator(train_ds), ...)

问题是getitem ()必须返回一批带索引的数据。但是,我使用的 list() 函数必须将整个数据集放入 RAM,因此当 DatasetGenerator 对象实例化时会达到内存限制(tf.data.Dataset 对象不支持使用 [] 进行索引)。

我的问题:

  1. 有没有办法实现getitem () (从数据集对象中获取特定批次)而不将整个对象放入内存?
  2. 如果第 1 项是不可能的,是否有任何解决方法?

提前致谢!

4

1 回答 1

2

我了解您担心将完整的数据集保存在内存中。

不用担心,tf.data.DatasetAPI 非常高效,它不会将您的完整数据集加载到内存中。

在内部,它只是创建了一系列函数,当使用它调用时,model.fit()它只会加载内存中的批次,而不是完整的数据集。

您可以在此链接中阅读更多内容,我正在粘贴文档中的重要部分。

tf.data.Dataset API 支持编写描述性和高效的输入管道。数据集的使用遵循一个常见的模式:

从输入数据创建源数据集。应用数据集转换来预处理数据。迭代数据集并处理元素。迭代以流式方式发生,因此整个数据集不需要放入内存中。

从最后一行可以了解到,tf.data.DatasetAPI 不会将完整的数据集加载到内存中,而是一次加载一批。

您必须执行以下操作来创建数据集的批次。

train_ds.batch(32)

这将创建 size 的批次32。您还可以使用预取来准备一批用于训练的批次。这消除了模型在训练一批并等待另一批后空闲的瓶颈。

train_ds.batch(32).prefetch(1)

您还可以使用cacheAPI 使您的数据管道更快。它将缓存您的数据集并使训练速度更快。

train_ds.batch(32).prefetch(1).cache()

generator因此,简而言之,如果您担心将整个数据集加载到内存中,则不需要tf.data.DatasetAPI 来处理它。

我希望我的回答能找到你。

于 2020-07-28T20:20:23.867 回答