I am trying to reproduce this paper. I use SGDW with the following learning-rate function:

lr = base_lr * (1 + gamma * iter)^(-power)
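
As a quick sanity check on the formula, here it is in symbols, evaluated at one point. The symbol names (\eta_0 for base_lr, p for power, t for iter) are shorthand chosen for this illustration, and the numbers are the ones passed to the schedule in this post (base_lr = 1e-5, gamma = 1e-3, power = 0.75):

```latex
\eta_t = \eta_0 \,(1 + \gamma t)^{-p}
\qquad\Rightarrow\qquad
\eta_{1000} = 10^{-5}\,(1 + 10^{-3}\cdot 1000)^{-0.75}
            = 10^{-5}\cdot 2^{-0.75}
            \approx 5.95\times 10^{-6}
```

So the rate should decrease monotonically from 1e-5 and roughly halve over the first couple of thousand iterations.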

My custom schedule:

import tensorflow as tf
import tensorflow_addons as tfa  # provides the SGDW optimizer


class Example(tf.keras.optimizers.schedules.LearningRateSchedule):
    def __init__(
        self,
        initial_learning_rate=1e-5,
        gamma=0.001,
        power=0.75,
        decay_steps=1,
        verbose=1,
        name="Example",
    ):
        super(Example, self).__init__()
        self.initial_learning_rate = initial_learning_rate
        self.gamma = gamma
        self.power = power
        self.verbose = verbose
        self.decay_steps = decay_steps
        self.name = name

    def __call__(self, step):
        # Only public TensorFlow ops are used here, so the internal
        # `ops` / `math_ops` modules do not need to be imported.
        with tf.name_scope(self.name or "Example"):
            initial_learning_rate = tf.convert_to_tensor(
                self.initial_learning_rate, name="initial_learning_rate"
            )
            dtype = initial_learning_rate.dtype
            gamma = tf.cast(self.gamma, dtype)
            power = tf.cast(self.power, dtype)
            decay_steps = tf.cast(self.decay_steps, dtype)

            # iter = step / decay_steps (with decay_steps=1 this is the raw step)
            iteration = tf.cast(step, dtype) / decay_steps

            # lr = base_lr * (1 + gamma * iter)^(-power)
            return initial_learning_rate * tf.pow(1.0 + gamma * iteration, -power)
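
For comparing against a per-epoch learning-rate plot, it may help to remember that Keras evaluates a schedule with the optimizer's iteration counter, which advances once per batch, not once per epoch. The following is a plain-Python sketch of the same formula, not the TensorFlow code itself; `lr_at_epoch_start` and the `steps_per_epoch` value are hypothetical and only there for illustration:

```python
def inverse_decay_lr(step, initial_learning_rate=1e-5, gamma=1e-3,
                     power=0.75, decay_steps=1):
    """Plain-Python mirror of the schedule:
    lr = initial_learning_rate * (1 + gamma * step / decay_steps) ** (-power)
    """
    iteration = step / decay_steps
    return initial_learning_rate * (1.0 + gamma * iteration) ** (-power)


def lr_at_epoch_start(epoch, steps_per_epoch, **schedule_kwargs):
    """Learning rate at the first batch of `epoch` (hypothetical helper).

    Assumes the optimizer step counter starts at 0 and grows by one per batch.
    """
    return inverse_decay_lr(epoch * steps_per_epoch, **schedule_kwargs)


if __name__ == "__main__":
    steps_per_epoch = 500  # assumption: replace with len(dataset) // batch_size
    for epoch in (0, 1, 10, 30):
        print(epoch, lr_at_epoch_start(epoch, steps_per_epoch))
```

If the values printed here do not match the logged learning rate at the same epochs, the mismatch is in how the step is counted rather than in the formula.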

My optimizer:

optimizer = tfa.optimizers.SGDW(
    learning_rate=Example(
        initial_learning_rate=1e-5,
        gamma=1e-3,
        power=0.75,
        verbose=smoke_test,
    ),
    weight_decay=5e-4,
    momentum=0.9,
)

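However, the loss and the mean absolute error (MAE) explode after about 30 epochs. Here are the plots of the training process:

[Figure: Learning Rate on epoch]
[Figure: Train and Validation Loss on epoch]
[Figure: Train and Validation MAE on epoch]

Something must be wrong here. I tried other optimizers and schedulers, and with those everything works fine.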
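
To pin down exactly where the blow-up starts instead of reading it off a plot, a small helper like the one below could be run on the per-epoch loss values. This is only a sketch: `first_divergence_epoch` and its `factor` threshold are hypothetical, and it assumes the losses are positive numbers such as those stored in `history.history["loss"]` after `model.fit`.

```python
import math


def first_divergence_epoch(losses, factor=10.0):
    """Return the index of the first epoch whose loss is NaN/inf or larger
    than `factor` times the best (lowest) loss seen so far, else None."""
    best = math.inf
    for epoch, loss in enumerate(losses):
        if not math.isfinite(loss) or loss > factor * best:
            return epoch
        best = min(best, loss)
    return None


if __name__ == "__main__":
    # Made-up numbers for illustration only, not the results from this post.
    example_losses = [1.0, 0.5, 0.4, 8.0]
    print(first_divergence_epoch(example_losses))
```

Knowing the exact epoch makes it easier to line the divergence up with the learning rate and step count at that point.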
