首先,您必须命名损失以使其可用于提前停止呼叫。如果您的损失变量在估计器中被命名为“损失”,则该行
copyloss = tf.identity(loss, name="loss")
就在它下面会起作用。
然后,使用此代码创建一个挂钩。
class EarlyStopping(tf.train.SessionRunHook):
def __init__(self,smoothing=.997,tolerance=.03):
self.lowestloss=float("inf")
self.currentsmoothedloss=-1
self.tolerance=tolerance
self.smoothing=smoothing
def before_run(self, run_context):
graph = ops.get_default_graph()
#print(graph)
self.lossop=graph.get_operation_by_name("loss")
#print(self.lossop)
#print(self.lossop.outputs)
self.element = self.lossop.outputs[0]
#print(self.element)
return tf.train.SessionRunArgs([self.element])
def after_run(self, run_context, run_values):
loss=run_values.results[0]
#print("loss "+str(loss))
#print("running average "+str(self.currentsmoothedloss))
#print("")
if(self.currentsmoothedloss<0):
self.currentsmoothedloss=loss*1.5
self.currentsmoothedloss=self.currentsmoothedloss*self.smoothing+loss*(1-self.smoothing)
if(self.currentsmoothedloss<self.lowestloss):
self.lowestloss=self.currentsmoothedloss
if(self.currentsmoothedloss>self.lowestloss+self.tolerance):
run_context.request_stop()
print("REQUESTED_STOP")
raise ValueError('Model Stopping because loss is increasing from EarlyStopping hook')
这会将指数平滑的损失验证与其最低值进行比较,如果容忍度更高,则停止训练。如果它停止得太早,提高容差和平滑会使它停止得更晚。保持平滑低于 1,否则它永远不会停止。
如果您想根据不同的条件停止,可以将 after_run 中的逻辑替换为其他内容。
现在,将此钩子添加到评估规范中。您的代码应如下所示:
eval_spec=tf.estimator.EvalSpec(input_fn=lambda:eval_input_fn(batchsize),steps=100,hooks=[EarlyStopping()])#
重要提示:函数 run_context.request_stop() 在 train_and_evaluate 调用中被破坏,并且不会停止训练。所以,我提出了一个价值错误来停止训练。因此,您必须将 train_and_evaluate 调用包装在 try catch 块中,如下所示:
try:
tf.estimator.train_and_evaluate(classifier,train_spec,eval_spec)
except ValueError as e:
print("training stopped")
如果你不这样做,当训练停止时代码将崩溃并出现错误。