I am currently experimenting with data augmentation in Keras. My model looks like this:

Layer (type)                 Output Shape              Param #   
=================================================================
input_2 (InputLayer)         [(None, 1024, 1280, 3)]   0         
_________________________________________________________________
lambda_1 (Lambda)            (None, 128, 128, 3)       0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 126, 126, 32)      896       
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 42, 42, 32)        9248      
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 21, 21, 32)        0         
_________________________________________________________________
dropout_2 (Dropout)          (None, 21, 21, 32)        0         
_________________________________________________________________
flatten_1 (Flatten)          (None, 14112)             0         
_________________________________________________________________
dense_2 (Dense)              (None, 128)               1806464   
_________________________________________________________________
dropout_3 (Dropout)          (None, 128)               0         
_________________________________________________________________
dense_3 (Dense)              (None, 6)                 774       
=================================================================
Total params: 1,817,382
Trainable params: 1,817,382
Non-trainable params: 0
_________________________________________________________________
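As a sanity check on the summary above, here is a minimal sketch that reproduces the output shapes and parameter counts. The kernel size (3x3), the stride of 3 in conv2d_3, the 2x2 pooling, and "valid" padding are assumptions inferred from the printed shapes and counts, not taken from the model definition; the helper names are hypothetical.

```python
def conv_out(size, kernel, stride=1):
    """Output length along one axis for an unpadded ('valid') conv or pooling layer."""
    return (size - kernel) // stride + 1


def conv_params(kernel, in_ch, out_ch):
    """Weights plus biases of a square-kernel Conv2D layer."""
    return kernel * kernel * in_ch * out_ch + out_ch


def dense_params(n_in, n_out):
    """Weights plus biases of a Dense layer."""
    return n_in * n_out + n_out


side = 128                             # lambda_1 output: 128 x 128 x 3
side_conv1 = conv_out(side, 3)         # conv2d_2, assumed 3x3 kernel, stride 1
side_conv2 = conv_out(side_conv1, 3, 3)  # conv2d_3, assumed 3x3 kernel, stride 3
side_pool = conv_out(side_conv2, 2, 2)   # max_pooling2d_1, assumed 2x2 pool
flat = side_pool * side_pool * 32      # flatten_1

total = (
    conv_params(3, 3, 32)      # conv2d_2
    + conv_params(3, 32, 32)   # conv2d_3
    + dense_params(flat, 128)  # dense_2
    + dense_params(128, 6)     # dense_3
)

print(side_conv1, side_conv2, side_pool, flat, total)
```

Under these assumptions the numbers agree with the summary, so almost all of the parameters (about 1.8 million) sit in dense_2.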

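The lambda layer basically just rescales the image.

Training works fine, but I don't have enough data, so generalization is poor. I therefore tried data augmentation.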

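# Import path assumes tf.keras (TensorFlow 2.x)
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Random geometric augmentations applied on the fly to each training batch
image_gen = ImageDataGenerator(
        rotation_range=20,
        zoom_range=0.15,
        width_shift_range=0.2,
        height_shift_range=0.2,
        shear_range=0.15,
        horizontal_flip=True,
        fill_mode="nearest")
# fit() only computes statistics for featurewise_center,
# featurewise_std_normalization and zca_whitening; none are enabled here
image_gen.fit(x_train, augment=True)
summary = model.fit_generator(
        image_gen.flow(x_train, y_train, batch_size=64),
        epochs=30,
        validation_data=(x_valid, y_valid))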

But now the fit no longer converges:

Epoch 1/30
3/3 [==============================] - 37s 12s/step - loss: 6.5018 - categorical_accuracy: 0.6000 - val_loss: 10.0990 - val_categorical_accuracy: 0.5217
Epoch 2/30
3/3 [==============================] - 34s 11s/step - loss: 6.5018 - categorical_accuracy: 0.6000 - val_loss: 10.0990 - val_categorical_accuracy: 0.5217
Epoch 3/30
3/3 [==============================] - 32s 11s/step - loss: 6.2290 - categorical_accuracy: 0.6000 - val_loss: 7.7606 - val_categorical_accuracy: 0.5217
Epoch 4/30
3/3 [==============================] - 36s 12s/step - loss: 6.7746 - categorical_accuracy: 0.6000 - val_loss: 10.0990 - val_categorical_accuracy: 0.5217
Epoch 5/30
3/3 [==============================] - 35s 12s/step - loss: 6.7746 - categorical_accuracy: 0.6000 - val_loss: 7.7606 - val_categorical_accuracy: 0.5217
Epoch 6/30
3/3 [==============================] - 35s 12s/step - loss: 7.3203 - categorical_accuracy: 0.6000 - val_loss: 10.0990 - val_categorical_accuracy: 0.5217
Epoch 7/30
1/3 [=========>....................] - ETA: 4s - loss: 2.6863 - categorical_accuracy: 0.8333
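For context on the "3/3" in the log: as far as I understand, when no steps_per_epoch is passed, the iterator returned by flow() yields ceil(n_samples / batch_size) batches per epoch. The sketch below (helper name hypothetical) lists which training-set sizes are consistent with 3 steps at batch_size=64.

```python
import math


def steps_per_epoch(n_samples, batch_size):
    """Number of batches needed to cover n_samples once (last batch may be partial)."""
    return math.ceil(n_samples / batch_size)


batch_size = 64
# All training-set sizes that would produce exactly 3 steps per epoch
sizes = [n for n in range(1, 1000) if steps_per_epoch(n, batch_size) == 3]
print(min(sizes), max(sizes))  # 129 192
```

So the training set has somewhere between 129 and 192 images, which means each epoch performs only three weight updates.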

I checked the generated images manually, and they look fine.
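Looking at the images does not show their numeric range or dtype, so a numeric check may be a useful complement. This is a minimal sketch: describe_batch is a hypothetical helper, and fake_batch is random stand-in data, not my real images.

```python
import numpy as np


def describe_batch(batch):
    """Summarize the dtype, shape and value range of an image batch."""
    batch = np.asarray(batch)
    return {
        "shape": batch.shape,
        "dtype": str(batch.dtype),
        "min": float(batch.min()),
        "max": float(batch.max()),
        "mean": float(batch.mean()),
        "has_nan": bool(np.isnan(batch).any()),
    }


# Stand-in data: 4 random 8x8 RGB images with values in [0, 255]
rng = np.random.default_rng(0)
fake_batch = rng.integers(0, 256, size=(4, 8, 8, 3)).astype("float32")
stats = describe_batch(fake_batch)
print(stats)
```

With the real data, the same helper could be called on x_train[:64] and on the images in next(image_gen.flow(x_train, y_train, batch_size=64))[0], to compare the statistics before and after augmentation.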
