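I found a PyTorch implementation that decays the batch norm `momentum` parameter from `0.1` in the first epoch to `0.001` in the last epoch. Is there a way to do the same with the batch norm `momentum` parameter in TF2, i.e., start at `0.9` and end at `0.999`? For example, this is how the PyTorch code does it: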
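```python
import numpy as np

# in training script: decay momentum geometrically from initial_momentum to final_momentum
momentum = initial_momentum * np.exp(-epoch / args.epochs * np.log(initial_momentum / final_momentum))
model_pos_train.set_bn_momentum(momentum)

# model class method: apply the new momentum to every batch norm layer
def set_bn_momentum(self, momentum):
    self.expand_bn.momentum = momentum
    for bn in self.layers_bn:
        bn.momentum = momentum
```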
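For reference, the PyTorch schedule is a geometric decay, and the TF2 values in the question follow from the two frameworks' opposite momentum conventions. Writing $E$ for `args.epochs`, $e$ for the current epoch, $m_0$ for `initial_momentum` and $m_f$ for `final_momentum`:

```latex
\begin{aligned}
m_e &= m_0 \exp\!\left(-\frac{e}{E}\ln\frac{m_0}{m_f}\right) = m_0\left(\frac{m_f}{m_0}\right)^{e/E},
\qquad m_{e=0} = m_0, \quad m_{e=E} = m_f \\
\text{PyTorch:}\quad \hat{x} &\leftarrow (1 - m)\,\hat{x} + m\,x_t \\
\text{Keras:}\quad \hat{x} &\leftarrow m'\,\hat{x} + (1 - m')\,x_t
\quad\Longrightarrow\quad m' = 1 - m
\end{aligned}
```

So decaying the PyTorch value from 0.1 to 0.001 corresponds to raising the Keras value from 0.9 to 0.999.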
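Solution:

The accepted answer below gives a working solution when using the `tf.keras.Model.fit()` API. However, I am using a custom training loop, so here is what I did instead.

After each epoch: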
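```python
import numpy as np

# Convert the Keras-style momenta to PyTorch-style (1 - m), decay, then convert back.
mi = 1 - initial_momentum  # i.e., initial_momentum = 0.9 -> mi = 0.1
mf = 1 - final_momentum    # i.e., final_momentum = 0.999 -> mf = 0.001
momentum = 1 - mi * np.exp(-epoch / epochs * np.log(mi / mf))
model = set_bn_momentum(model, momentum)
```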
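A minimal sketch of the same schedule as a standalone function, so the endpoints can be checked outside the training loop. The name `keras_bn_momentum` and its default arguments are illustrative, not part of the original code; it assumes `epoch` runs from 0 up to `epochs`.

```python
import math


def keras_bn_momentum(epoch, epochs, initial_momentum=0.9, final_momentum=0.999):
    """Keras-style batch norm momentum for `epoch` (hypothetical helper).

    Converts to PyTorch-style momentum (1 - m), decays it geometrically
    exactly as in the training-loop snippet, then converts back.
    """
    mi = 1.0 - initial_momentum  # PyTorch-style initial momentum, ~0.1
    mf = 1.0 - final_momentum    # PyTorch-style final momentum, ~0.001
    torch_style = mi * math.exp(-epoch / epochs * math.log(mi / mf))
    return 1.0 - torch_style


if __name__ == "__main__":
    epochs = 10
    for epoch in range(epochs + 1):
        print(epoch, round(keras_bn_momentum(epoch, epochs), 6))
```

Note that if the training loop uses `range(epochs)`, the last epoch is `epochs - 1`, so the momentum stops just short of 0.999; dividing by `epochs - 1` instead would land exactly on the final value. That is a design choice; the original code divides by `epochs`.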
The `set_bn_momentum` function (credit to this article):
```python
import os
import tempfile

import tensorflow as tf


def set_bn_momentum(model, momentum):
    for layer in model.layers:
        if hasattr(layer, 'momentum'):
            print(layer.name, layer.momentum)  # previous value, for debugging
            setattr(layer, 'momentum', momentum)
    # Changing the layer attributes only updates the model config,
    # so rebuild the model from that config for the new value to take effect.
    model_json = model.to_json()
    # Save the weights before reloading the model
    tmp_weights_path = os.path.join(tempfile.gettempdir(), 'tmp_weights.h5')
    model.save_weights(tmp_weights_path)
    # Load the model from the config
    model = tf.keras.models.model_from_json(model_json)
    # Reload the model weights
    model.load_weights(tmp_weights_path, by_name=True)
    return model
```
This approach does not add significant overhead to the training procedure.
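An alternative sketch that keeps the same idea (rebuild the model with the new momentum) but avoids the temporary `.h5` file: it calls `tf.keras.models.clone_model` with a `clone_function` that edits each `BatchNormalization` layer's config, then copies the weights in memory with `get_weights()`/`set_weights()`. The name `rebuild_with_bn_momentum` is hypothetical. It assumes a Sequential or functional model (subclassed models cannot be cloned this way) and, like the original, only looks at top-level layers. As with the original approach, the rebuilt model owns new variables, so an optimizer or `tf.function` training step built around the old model may need to be re-created; how much that matters likely depends on your TF/Keras version.

```python
import numpy as np
import tensorflow as tf


def rebuild_with_bn_momentum(model, momentum):
    """Return a copy of `model` whose BatchNormalization layers use `momentum`.

    Hypothetical helper: clones the architecture with an edited layer config,
    then copies the current weights (including moving mean/variance) in memory.
    """

    def clone_layer(layer):
        config = layer.get_config()
        if isinstance(layer, tf.keras.layers.BatchNormalization):
            config["momentum"] = float(momentum)
        return layer.__class__.from_config(config)

    new_model = tf.keras.models.clone_model(model, clone_function=clone_layer)
    # Same architecture, so the weight lists line up one-to-one.
    new_model.set_weights(model.get_weights())
    return new_model
```

In the custom loop this would replace the original call as `model = rebuild_with_bn_momentum(model, momentum)`. Building a fresh model from an edited config, rather than mutating the existing layers, leaves the old model untouched.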