
I am currently training a large object detection model in TensorFlow 2 using a custom training loop with gradient tape. The problem is that the model is not improving the loss, because the gradients are very low. I reproduced the problem on a simple classification task using cifar10 and found that a small model trains fine without issues, while a larger model (VGG16) does not improve the loss at all. Below is some code to reproduce the problem.

The VGG16 model:

import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten, Conv2D, MaxPooling2D, Input

def create_vgg16(number_classes, include_fully=True, input_shape=(300, 300, 3), input_tensor=None):
    if input_tensor is None:
        img_input = Input(shape=input_shape)
    else:
        img_input = input_tensor
    x = Conv2D(64, (3, 3), activation='relu', padding='same', kernel_initializer='he_normal', name='conv1_1')(img_input)
    x = Conv2D(64, (3, 3), activation='relu', padding='same', kernel_initializer='he_normal', name='conv1_2')(x)
    x = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding='same', name='pool1')(x)

    x = Conv2D(128, (3, 3), activation='relu', padding='same', kernel_initializer='he_normal', name='conv2_1')(x)
    x = Conv2D(128, (3, 3), activation='relu', padding='same', kernel_initializer='he_normal', name='conv2_2')(x)
    x = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding='same', name='pool2')(x)

    x = Conv2D(256, (3, 3), activation='relu', padding='same', kernel_initializer='he_normal', name='conv3_1')(x)
    x = Conv2D(256, (3, 3), activation='relu', padding='same', kernel_initializer='he_normal', name='conv3_2')(x)
    x = Conv2D(256, (3, 3), activation='relu', padding='same', kernel_initializer='he_normal', name='conv3_3')(x)
    x = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding='same', name='pool3')(x)

    x = Conv2D(512, (3, 3), activation='relu', padding='same', kernel_initializer='he_normal', name='conv4_1')(x)
    x = Conv2D(512, (3, 3), activation='relu', padding='same', kernel_initializer='he_normal', name='conv4_2')(x)
    x = Conv2D(512, (3, 3), activation='relu', padding='same', kernel_initializer='he_normal', name='conv4_3')(x)
    x = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding='same', name='pool4')(x)

    x = Conv2D(512, (3, 3), activation='relu', padding='same', kernel_initializer='he_normal', name='conv5_1')(x)
    x = Conv2D(512, (3, 3), activation='relu', padding='same', kernel_initializer='he_normal', name='conv5_2')(x)
    x = Conv2D(512, (3, 3), activation='relu', padding='same', kernel_initializer='he_normal', name='conv5_3')(x)
    x = MaxPooling2D(pool_size=(3, 3), strides=(1, 1), padding='same', name='pool5')(x)

    if include_fully:
        x = Flatten(name='flatten')(x)
        x = Dense(4096, activation='relu', name='fc1')(x)
        x = Dense(4096, activation='relu', name='fc2')(x)
        x = Dense(number_classes, activation='softmax', name='predictions')(x)

    if input_tensor is not None:
        inputs = tf.keras.utils.get_source_inputs(input_tensor)
    else:
        inputs = img_input
    model = tf.keras.models.Model(inputs, x, name='vgg16')

    return model

The small CNN model:

def create_small_cnn(n_classes, input_shape=(32, 32, 3)):
    img_input = tf.keras.Input(shape=input_shape)
    x = tf.keras.layers.Conv2D(64, (3, 3), activation='relu', padding='same', name='conv1_1')(img_input)
    x = tf.keras.layers.Conv2D(64, (3, 3), activation='relu', padding='same', name='conv1_2')(x)
    x = tf.keras.layers.Flatten(name='flatten')(x)
    x = tf.keras.layers.Dense(16, activation='relu', name='fc1')(x)
    x = tf.keras.layers.Dense(n_classes, activation='softmax', name='softmax')(x)

    model = tf.keras.Model(img_input, x, name='small_cnn')
    return model

The training loop:

def main():
    number_classes = 10
    # Load and one-hot encode the data
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
    y_train = tf.reshape(y_train, [-1])
    y_train = tf.one_hot(y_train, number_classes).numpy()
    y_test = tf.reshape(y_test, [-1])
    y_test = tf.one_hot(y_test, number_classes).numpy()


    # Define model
    model = create_vgg16(number_classes, input_shape=(32, 32, 3))
    # model = create_small_cnn(number_classes, input_shape=(32, 32, 3))

    # Instantiate an optimizer to train the model.
    optimizer = tf.keras.optimizers.Adam(learning_rate=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-08)

    # Instantiate a loss function.
    loss_fn = tf.keras.losses.CategoricalCrossentropy()

    # Prepare the metrics.
    train_acc_metric = tf.keras.metrics.CategoricalAccuracy()
    val_acc_metric = tf.keras.metrics.CategoricalAccuracy()

    # Prepare the training dataset.
    batch_size = 64
    train_dataset = tf.data.Dataset.from_tensor_slices(
      (tf.cast(x_train/255, tf.float32),
       tf.cast(y_train,tf.int64)))
    train_dataset = train_dataset.shuffle(buffer_size=1024).batch(batch_size)

    # Prepare the validation dataset.
    val_dataset = tf.data.Dataset.from_tensor_slices(
      (tf.cast(x_test/255, tf.float32),
       tf.cast(y_test,tf.int64)))
    val_dataset = val_dataset.shuffle(buffer_size=1024).batch(batch_size)

    model.summary()

    for epoch in range(100):
      print('Start of epoch %d' % (epoch,))
      for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
        with tf.GradientTape() as tape:
          # The model ends in a softmax, so these are class probabilities;
          # CategoricalCrossentropy with from_logits=False (default) matches.
          logits = model(x_batch_train, training=True)
          loss_value = loss_fn(y_batch_train, logits)
        grads = tape.gradient(loss_value, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))

        # Update the metric with the full batch rather than a sliced sample.
        train_acc_metric.update_state(y_batch_train, logits)

        if step % 200 == 0:
            print('Training loss (for one batch) at step %s: %s' % (step, float(loss_value)))


      # Display metrics at the end of each epoch.
      train_acc = train_acc_metric.result()
      print('Training acc over epoch: %s' % (float(train_acc),))
      # Reset training metrics at the end of each epoch
      train_acc_metric.reset_states()

      # Run a validation loop at the end of each epoch.
      for x_batch_val, y_batch_val in val_dataset:
        val_logits = model(x_batch_val, training=False)
        val_acc_metric.update_state(y_batch_val, val_logits)
      val_acc = val_acc_metric.result()
      val_acc_metric.reset_states()
      print('Validation acc: %s' % (float(val_acc),))


if __name__ == '__main__':
    main()
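
To see how small the gradients actually get, one can log the global gradient norm inside the training step (a diagnostic sketch, not part of the original training code; tf.linalg.global_norm is the standard TensorFlow helper, and grads and step come from the loop above):

# Diagnostic sketch: place right after grads = tape.gradient(...) in the
# training loop to print the overall gradient magnitude alongside the loss.
grad_norm = tf.linalg.global_norm(grads)
if step % 200 == 0:
    print('Global gradient norm at step %s: %s' % (step, float(grad_norm)))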

If you run the code shown, you will see that the network trains fine when using the small CNN model. On the other hand, it does not work with the standard VGG16 model on exactly the same dataset with exactly the same preprocessing. Even more confusingly, the VGG model trains perfectly well when using model.fit instead of the custom training loop with gradient tape.
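
For comparison, the working model.fit counterpart looks roughly like this (a minimal sketch; it assumes the same train_dataset and val_dataset pipeline built in main above):

# Minimal sketch of the model.fit comparison -- same data pipeline as above.
model = create_vgg16(number_classes, input_shape=(32, 32, 3))
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
              loss=tf.keras.losses.CategoricalCrossentropy(),
              metrics=[tf.keras.metrics.CategoricalAccuracy()])
model.fit(train_dataset, validation_data=val_dataset, epochs=100)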

Does anyone know why this happens and how to fix it?
