我正在使用 Tensorflow Dataset API 并从 TFRecord 文件中读取数据。我可以使用 map 函数并使用 random_flip_left_right、random_crop 等方法进行数据增强。
但是,当我尝试复制 AlexNet 论文时,我遇到了一个问题。我需要翻转每张图像,然后进行 5 次裁剪(左、上、下、右和中)。
因此输入数据集大小将增加 10 倍。无论如何使用tensorflow数据集API来做到这一点?map() 函数只返回一张图像,我无法增加图像的数量。
请查看我现在拥有的代码。
dataset = dataset.map(parse_image, num_parallel_calls=tf.data.experimental.AUTOTUNE) \
.map(lambda image, label: (tf.image.random_flip_left_right(image), label), num_parallel_calls=tf.data.experimental.AUTOTUNE) \
.map(lambda image, label: (tf.image.random_crop(image, size=[227, 227, 3]), label), num_parallel_calls=tf.data.experimental.AUTOTUNE) \
.shuffle(buffer_size=1000) \
.repeat() \
.batch(256) \
.prefetch(tf.data.experimental.AUTOTUNE)