I want to modify Adam's get_gradients method so that it performs gradient centralization; the code below is adapted from https://keras.io/examples/vision/gradient_centralization/

My question is: why does the print('ttt') inside get_gradients print nothing when I call fit? The print('dddd') in __init__ works and prints dddd on the screen.

from tensorflow.keras.optimizers import Adam
import tensorflow as tf

class GCAdam(Adam):
    def __init__(self, learning_rate=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-7, amsgrad=False, name='Adam', **kwargs):
        super().__init__(learning_rate=learning_rate, beta_1=beta_1, beta_2=beta_2, epsilon=epsilon, amsgrad=amsgrad, name=name, **kwargs)
        print('dddd')
   
    def get_gradients(self, loss, params):
        # Override get_gradients() so that the gradients are centralized
        # before the optimizer uses them.

        grads = []
        gradients = super().get_gradients(loss, params)
        
        print('ttt')
        # Centralize each gradient that has more than one axis by subtracting
        # its mean over all axes except the last; 1-D gradients (e.g. biases)
        # are left unchanged.
        for grad in gradients:
            grad_len = len(grad.shape)
            if grad_len > 1:
                axis = list(range(grad_len - 1))
                grad -= tf.reduce_mean(grad, axis=axis, keepdims=True)
            grads.append(grad)

        return grads
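
model, x_train, y_train, batch_size, and epochs are defined elsewhere in my script; a minimal stand-in like the one below (toy random data and a small dense model, not my real setup) reproduces the same behavior:

import numpy as np
from tensorflow.keras import layers, models

# Toy placeholders so the snippet is self-contained; my real data/model differ.
x_train = np.random.rand(256, 32).astype("float32")
y_train = tf.keras.utils.to_categorical(np.random.randint(0, 10, size=256), num_classes=10)
batch_size = 32
epochs = 1

model = models.Sequential([
    layers.Dense(64, activation="relu", input_shape=(32,)),
    layers.Dense(10, activation="softmax"),
])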


optimizergc = GCAdam(learning_rate=1e-4)

model.compile(loss="categorical_crossentropy", optimizer=optimizergc, metrics=["accuracy"])

model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, validation_split=0.1,verbose=1)