我正在 Google Colab 上使用 Keras 进行图像分类。我使用 tf.keras.preprocessing.image_dataset_from_directory() 函数(https://www.tensorflow.org/api_docs/python/tf/keras/preprocessing/image_dataset_from_directory)加载图像,该函数返回一个 tf.data.Dataset 对象:
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
data_dir,
validation_split=0.2,
subset="training",
seed=1234,
image_size=(img_height, img_width),
batch_size=batch_size,
label_mode="categorical")
我发现当数据包含数千张图像时,model.fit() 将在训练多个批次后使用所有内存(我使用的是 Google Colab,并且可以看到 RAM 使用量在第一个 epoch 期间增长)。然后我尝试使用 Keras 序列,这是将部分数据加载到 RAM 中的建议解决方案(https://www.tensorflow.org/api_docs/python/tf/keras/utils/Sequence):
class DatasetGenerator(tf.keras.utils.Sequence):
def __init__(self, dataset):
self.dataset = dataset
def __len__(self):
return tf.data.experimental.cardinality(self.dataset).numpy()
def __getitem__(self, idx):
return list(self.dataset.as_numpy_iterator())[idx]
我用以下方法训练模型:
history = model.fit(DatasetGenerator(train_ds), ...)
问题是getitem ()必须返回一批带索引的数据。但是,我使用的 list() 函数必须将整个数据集放入 RAM,因此当 DatasetGenerator 对象实例化时会达到内存限制(tf.data.Dataset 对象不支持使用 [] 进行索引)。
我的问题:
- 有没有办法实现getitem () (从数据集对象中获取特定批次)而不将整个对象放入内存?
- 如果第 1 项是不可能的,是否有任何解决方法?
提前致谢!