我正在编写自己的回调以根据某些自定义条件停止训练。一旦满足条件,EarlyStopping 就会停止训练:
self.model.stop_training = True
例如来自https://www.tensorflow.org/guide/keras/custom_callback
class EarlyStoppingAtMinLoss(keras.callbacks.Callback): """当损失达到最小值时停止训练,即损失停止减少。
参数: 耐心:达到 min 后等待的 epoch 数。在这个数量没有改善之后,训练就停止了。"""
def __init__(self, patience=0):
super(EarlyStoppingAtMinLoss, self).__init__()
self.patience = patience
# best_weights to store the weights at which the minimum loss occurs.
self.best_weights = None
def on_train_begin(self, logs=None):
# The number of epoch it has waited when loss is no longer minimum.
self.wait = 0
# The epoch the training stops at.
self.stopped_epoch = 0
# Initialize the best as infinity.
self.best = np.Inf
def on_epoch_end(self, epoch, logs=None):
current = logs.get("loss")
if np.less(current, self.best):
self.best = current
self.wait = 0
# Record the best weights if current results is better (less).
self.best_weights = self.model.get_weights()
else:
self.wait += 1
if self.wait >= self.patience:
self.stopped_epoch = epoch
self.model.stop_training = True
print("Restoring model weights from the end of the best epoch.")
self.model.set_weights(self.best_weights)
def on_train_end(self, logs=None):
if self.stopped_epoch > 0:
print("Epoch %05d: early stopping" % (self.stopped_epoch + 1))
问题是,它不适用于 tensorflow 2.2 和 2.3。任何解决方法的想法?还有什么办法可以停止在 tf 2.3 中训练模型?