类似下面这样的代码应该可行。我没有测试过,也没有尝试去完善它……但所需的各个部分都已经具备,您可以按自己的需要进行调整。
import tensorflow as tf
from tensorflow.keras.callbacks import Callback
import numpy as np
class LRSetter(Callback):
    """Learning-rate schedule with a per-batch warm-up and a per-epoch decay.

    Phase 1 (warm-up): for the first ``start_mid_batches`` training batches,
    counted cumulatively across epochs, the optimizer's learning rate is set
    before each batch to the next value of a linear ramp from ``start_lr``
    to ``middle_lr``.

    Phase 2 (decay): once the warm-up ramp is exhausted, the learning rate is
    set at the start of every epoch from a linear ramp of ``end_epochs``
    values running from ``middle_lr`` to ``end_lr``, indexed by epoch number.
    Epochs past the end of that ramp keep its final value.

    Args:
        start_lr: Learning rate used for the very first batch.
        middle_lr: Learning rate reached at the end of the warm-up and used
            as the starting point of the decay ramp.
        end_lr: Final value of the decay ramp.
        start_mid_batches: Number of batches over which to warm up. A value
            of 0 disables the warm-up phase.
        end_epochs: Number of points in the per-epoch decay ramp.
    """

    def __init__(self, start_lr=0, middle_lr=0.001, end_lr=0.00001,
                 start_mid_batches=200, end_epochs=2000):
        super().__init__()
        self.start_mid_lr = np.linspace(start_lr, middle_lr, start_mid_batches)
        # Not exactly right: the decay ramp is indexed by the absolute epoch
        # number, so by the time the warm-up has finished a few epochs of the
        # ramp have already been skipped. Good enough to show the idea.
        self.mid_end_lr = np.linspace(middle_lr, end_lr, end_epochs)
        self.start_mid_batches = start_mid_batches
        # With no warm-up batches the per-epoch schedule is in charge at once.
        self.epoch_takeover = start_mid_batches <= 0
        # Cumulative number of warm-up batches handled so far. The ``batch``
        # argument Keras passes restarts at 0 every epoch, so it cannot be
        # used to index a ramp that spans several epochs.
        self._batches_seen = 0

    def _set_lr(self, value):
        """Write ``value`` into the attached model's optimizer learning rate."""
        tf.keras.backend.set_value(self.model.optimizer.lr, value)

    def on_train_batch_begin(self, batch, logs=None):
        """Apply the next warm-up learning rate until the ramp is used up."""
        if self.epoch_takeover:
            return
        self._set_lr(self.start_mid_lr[self._batches_seen])
        self._batches_seen += 1
        if self._batches_seen >= self.start_mid_batches:
            # Warm-up finished: hand control to the per-epoch decay.
            self.epoch_takeover = True

    def on_epoch_begin(self, epoch, logs=None):
        """Apply the decay learning rate for ``epoch`` once warm-up is over."""
        if not self.epoch_takeover or len(self.mid_end_lr) == 0:
            return
        # Clamp so training longer than ``end_epochs`` holds the final rate
        # instead of raising IndexError.
        index = min(epoch, len(self.mid_end_lr) - 1)
        self._set_lr(self.mid_end_lr[index])