1

在对 resnet-50 进行对抗训练时,我遇到了一个奇怪的问题,我不确定这是逻辑错误,还是代码/库中的某个错误。我正在对抗性地训练从 Keras 加载的 resnet-50,使用来自cleverhans 的 FastGradientMethod,并期望对抗性准确度至少提高 90% 以上(可能是 99.x%)。训练算法、训练和攻击参数应该在代码中可见。正如标题中已经指出的那样,问题在于,在第一个 epoch 中训练了 39002 个训练输入中的 3000 个之后,准确率停留在 5%。(德国交通标志识别基准,GTSRB)。

在没有对抗损失函数的情况下进行训练时,在 3000 个样本后准确率不会卡住,而是在第一个 epoch 继续上升 > 0.95。

当用 lenet-5、alexnet 和 vgg19 替换网络时,代码按预期工作,并且实现了与非对抗性 categorical_corssentropy 损失函数绝对可比的精度。我也尝试过仅使用 tf-cpu 和不同版本的 tensorflow 运行该过程,结果始终相同。

获取 ResNet-50 的代码:

def build_resnet50(num_classes, img_size):
    from tensorflow.keras.applications import ResNet50
    from tensorflow.keras import Model
    from tensorflow.keras.layers import Dense, Flatten
    resnet = ResNet50(weights='imagenet', include_top=False, input_shape=img_size)
    x = Flatten(input_shape=resnet.output.shape)(resnet.output)
    x = Dense(1024, activation='sigmoid')(x)
    predictions = Dense(num_classes, activation='softmax', name='pred')(x)
    model = Model(inputs=[resnet.input], outputs=[predictions])
    return model

训练:

def lr_schedule(epoch):
    # decreasing learning rate depending on epoch
    return 0.001 * (0.1 ** int(epoch / 10))


def train_model(model, xtrain, ytrain, xtest, ytest, lr=0.001, batch_size=32, 
epochs=10, result_folder=""):
    from cleverhans.attacks import FastGradientMethod
    from cleverhans.utils_keras import KerasModelWrapper
    import tensorflow as tf

    from tensorflow.keras.optimizers import SGD
    from tensorflow.keras.callbacks import LearningRateScheduler, ModelCheckpoint
    sgd = SGD(lr=lr, decay=1e-6, momentum=0.9, nesterov=True)

    model(model.input)

    wrap = KerasModelWrapper(model)
    sess = tf.compat.v1.keras.backend.get_session()
    fgsm = FastGradientMethod(wrap, sess=sess)
    fgsm_params = {'eps': 0.01,
                   'clip_min': 0.,
                   'clip_max': 1.}

    loss = get_adversarial_loss(model, fgsm, fgsm_params)

    model.compile(loss=loss, optimizer=sgd, metrics=['accuracy'])

    model.fit(xtrain, ytrain,
                    batch_size=batch_size,
                    validation_data=(xtest, ytest),
                    epochs=epochs,
                    callbacks=[LearningRateScheduler(lr_schedule)])

损失函数:

def get_adversarial_loss(model, fgsm, fgsm_params):
    def adv_loss(y, preds):
         import tensorflow as tf

        tf.keras.backend.set_learning_phase(False) #turn off dropout during input gradient calculation, to avoid unconnected gradients

        # Cross-entropy on the legitimate examples
        cross_ent = tf.keras.losses.categorical_crossentropy(y, preds)

        # Generate adversarial examples
        x_adv = fgsm.generate(model.input, **fgsm_params)
        # Consider the attack to be constant
        x_adv = tf.stop_gradient(x_adv)

        # Cross-entropy on the adversarial examples
        preds_adv = model(x_adv)
        cross_ent_adv = tf.keras.losses.categorical_crossentropy(y, preds_adv)

        tf.keras.backend.set_learning_phase(True) #turn back on

        return 0.5 * cross_ent + 0.5 * cross_ent_adv
    return adv_loss

使用的版本:tf+tf-gpu:1.14.0 keras:2.3.1cleverhans:> 3.0.1 - 从 github 提取的最新版本

4

1 回答 1

0

这是我们在 BatchNormalization 上估计移动平均值的方式的副作用。

您使用的训练数据的均值和方差与用于训练 ResNet50 的数据集不同。因为 BatchNormalization 上的动量具有默认值 0.99,只有 10 次迭代它不能足够快地收敛到移动均值和方差的正确值。当 learning_phase 为 1 时,这在训练期间并不明显,因为 BN 使用批次的均值/方差。然而,当我们将 learning_phase 设置为 0 时,在训练期间学习到的不正确的均值/方差值会显着影响准确性。

您可以通过以下方法解决此问题:

  1. 更多迭代

将批次的大小从 32 减少到 16(以在每个 epoch 执行更多更新)并将 epoch 的数量从 10 增加到 250。这样移动平均值和方差将收敛到正确的值。

  1. 改变 BatchNormalization 的势头

保持迭代次数固定,但更改 BatchNormalization 层的动量以更积极地更新滚动均值和方差(不推荐用于生产模型)。

在原始代码片段中,在读取 base_model 和定义新层之间添加以下代码:

# ....
base_model = ResNet50(weights='imagenet', include_top=False, input_shape=input_shape)

# PATCH MOMENTUM - START
import json
conf = json.loads(base_model.to_json())
for l in conf['config']['layers']:
    if l['class_name'] == 'BatchNormalization':
        l['config']['momentum'] = 0.5


m = Model.from_config(conf['config'])
for l in base_model.layers:
    m.get_layer(l.name).set_weights(l.get_weights())

base_model = m
# PATCH MOMENTUM - END

x = base_model.output
# ....

还建议您尝试此处提供的另一种黑客攻击。

于 2020-05-07T11:44:25.690 回答