我目前正在使用生成器来生成我的训练和验证数据集tf.data.Dataset.from_generator
。我有一个类方法可以为我解决这个问题:
def build_dataset(self, batch_size=16, shuffle=16, validation=None):
train_dataset = tf.data.Dataset.from_generator(import_images(validation=validation), (tf.float32, tf.float32))
self.train_dataset = train_dataset.shuffle(shuffle).repeat(-1).batch(batch_size).prefetch(1)
if validation is not None:
val_dataset = tf.data.Dataset.from_generator(import_images(validation=validation), (tf.float32, tf.float32))
self.val_dataset = val_dataset.repeat(1).batch(batch_size).prefetch(1)
问题是传递(validation=validation)
给我import_images
的生成器创建了 Tensorflow 不需要的生成器对象,它给了我错误:
TypeError: `generator` must be callable.
因为我必须传入validation
告诉我的生成器生成单独的训练和验证版本,所以我需要创建同一个生成器的两个版本。它也不允许我传递其他参数来控制训练和验证示例的百分比——这意味着生成器必须是静态的。有什么建议么?