我正在研究一个基于cifar10
数据集的小型项目。我已经从图像增强技术中加载数据tfds.load(...)
并练习。
当我使用tf.data.Dataset
对象,这是我的数据集时,实时数据增强是非常无法实现的,因此我想将所有功能传递到tf.keras.preprocessing.image.ImageDataGenerator.flow(...)
以获得实时增强的功能。
但是这个flow(...)
方法接受与tf.data.Dataset
对象没有任何关系的 NumPy 数组。
有人可以在这方面(或任何替代方法)指导我吗?我该如何进一步进行?
tf.image
转换是实时的吗?如果没有,除了 ,还有什么最好的方法ImageDataGenerator.flow(...)
?
我的代码:
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.keras.preprocessing.image import ImageDataGenerator
splitting = tfds.Split.ALL.subsplit(weighted=(70, 20, 10))
dataset_cifar10, dataset_info = tfds.load(name='cifar10',
split=splitting,
as_supervised=True,
with_info=True)
train_dataset, valid_dataset, test_dataset = dataset_cifar10
BATCH_SIZE = 32
train_dataset = train_dataset.batch(batch_size=BATCH_SIZE)
train_dataset = train_dataset.prefetch(buffer_size=1)
image_generator = ImageDataGenerator(rotation_range=45,
width_shift_range=0.15,
height_shift_range=0.15,
zoom_range=0.2,
horizontal_flip=True,
vertical_flip=True,
rescale=1./255)
train_dataset_generator = image_generator.flow(...)
...