我想改变get_gradientsAdam执行梯度中心化的方法,以下代码来自https://keras.io/examples/vision/gradient_centralization/
我的问题是为什么执行 fit 函数时
print('ttt')方法中的代码什么也不打印?工作并在屏幕上打印 dddd。get_gradientsprint('dddd')
from tensorflow.keras.optimizers import Adam
import tensorflow as tf
class GCAdam(Adam):
def __init__(self, learning_rate=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-7, amsgrad=False, name='Adam', **kwargs):
super().__init__(learning_rate=learning_rate, beta_1=beta_1, beta_2=beta_2, epsilon=epsilon, amsgrad=amsgrad, name=name, **kwargs)
print('dddd')
def get_gradients(self, loss, params):
# We here just provide a modified get_gradients() function since we are
# trying to just compute the centralized gradients.
grads = []
gradients = super().get_gradients()
print('ttt')
for grad in gradients:
grad_len = len(grad.shape)
if grad_len > 1:
axis = list(range(grad_len - 1))
grad -= tf.reduce_mean(grad, axis=axis, keep_dims=True)
grads.append(grad)
return grads
optimizergc = GCAdam(learning_rate=1e-4)
model.compile(loss="categorical_crossentropy", optimizer=optimizergc, metrics=["accuracy"])
model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, validation_split=0.1,verbose=1)