我正在研究图像分类 CNN,我想知道如何使用 ImageGenerator 获得 x_train 和 y_train 形式。这样做的原因是我想将我的模型拟合为 fit note fit.generator
trainDataGen = ImageDataGenerator(rescale= 1./255,
rotation_range =30,
width_shift_range=0.1,
height_shift_range=0.1,
shear_range=0.2,
zoom_range=0.2,
horizontal_flip=False,
vertical_flip=False,
fill_mode='nearest',
)
trainGenSet = trainDataGen.flow_from_directory(
path +'Train',
target_size=(28,28),
batch_size=32,
class_mode='categorical',
color_mode='grayscale',
)
x_train, y_train = trainGenSet.next()
由于批量大小为 32,print(x_train) 为 (32,28,28,1)。我总共有 5000 个训练数据集,我想为 print(x_train) 获取 (5000,28,28,1)