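I'm training an object detector with the Faster R-CNN model using a pipeline config file. I know I can stop training manually (Ctrl+C), but I'd like it to stop automatically based on the loss value. How can I do this? I know Keras callbacks can do this when monitoring epochs. Is there a similar option when using a config file and a pre-trained model, where training is tracked in steps? Thanks.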
1 Answer
It might just be a hack, but I found a solution to my question.
The Object Detection API requires the `tf_slim` package to be installed, and within `tf_slim` there is a module called `learning.py`.
The full path to it might look something like this: /usr/local/lib/python3.6/site-packages/tf_slim/learning.py
In `learning.py`, starting at line 764 (the exact line may differ between versions), the code looks something like this:
```python
try:
    while not sv.should_stop():
        total_loss, should_stop = train_step_fn(sess, train_op, global_step,
                                                train_step_kwargs)
        if should_stop:
            logging.info('Stopping Training.')
            sv.request_stop()
            break
except errors.OutOfRangeError as e:
    # OutOfRangeError is thrown when epoch limit per
    # tf.compat.v1.train.limit_epochs is reached.
    logging.info('Caught OutOfRangeError. Stopping Training. %s', e)
```
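To see the control flow in isolation, here is a pure-Python mock of that loop. `FakeSupervisor`, `run_loop` and `fake_train_step` are hypothetical stand-ins, not tf_slim classes. The only point is that training ends as soon as the step function returns `should_stop=True`; in tf_slim's default step function I believe this flag comes from the `number_of_steps` limit, which the mock imitates with a fixed budget of 10 steps.

```python
# Hypothetical stand-ins for tf_slim's Supervisor and train_step_fn, used only
# to show how the returned should_stop flag ends the training loop.
class FakeSupervisor:
    def __init__(self):
        self._stop_requested = False

    def should_stop(self):
        return self._stop_requested

    def request_stop(self):
        self._stop_requested = True


def run_loop(sv, train_step_fn):
    """Mirror of the loop structure above; returns the number of steps run."""
    steps = 0
    while not sv.should_stop():
        total_loss, should_stop = train_step_fn(steps)
        steps += 1
        if should_stop:
            print('Stopping Training.')
            sv.request_stop()
            break
    return steps


def fake_train_step(step):
    # Pretend the loss decays by 1.0 per step and the step function asks to
    # stop once a budget of 10 steps is used up (like number_of_steps).
    loss = 10.0 - step
    return loss, step + 1 >= 10


if __name__ == '__main__':
    print(run_loop(FakeSupervisor(), fake_train_step))  # → 10
```

So anything that flips `should_stop` to `True` inside the loop ends training cleanly through `sv.request_stop()`.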
I wrote a small `if` statement that checks the maximum of the last five `total_loss` values and, if it is below a certain threshold (3 in this case), sets `should_stop` to `True`. This is shown below:
```python
try:
    total_loss_list = []  # loss history for the early-stopping check
    while not sv.should_stop():
        total_loss, should_stop = train_step_fn(sess, train_op, global_step,
                                                train_step_kwargs)
        total_loss_list.append(total_loss)
        # Stop once the last five losses are all below the threshold (3 here).
        if len(total_loss_list) >= 5 and max(total_loss_list[-5:]) < 3:
            should_stop = True
        if should_stop:
            logging.info('Stopping Training.')
            sv.request_stop()
            break
except errors.OutOfRangeError as e:
    # OutOfRangeError is thrown when epoch limit per
    # tf.compat.v1.train.limit_epochs is reached.
    logging.info('Caught OutOfRangeError. Stopping Training. %s', e)
```
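The same stopping rule can be pulled out into a small helper so it can be tested without TensorFlow. This is a sketch, not part of tf_slim: `LossPlateauStopper` is a hypothetical name, and it uses a `collections.deque` with `maxlen` so only the last `window` losses are kept instead of the whole loss history.

```python
from collections import deque


class LossPlateauStopper:
    """Signal a stop once the last `window` losses are all below `threshold`.

    Same rule as the patched loop (max of the last five losses < 3), but the
    deque keeps only `window` values instead of every loss seen so far.
    """

    def __init__(self, threshold=3.0, window=5):
        self.threshold = threshold
        self.recent = deque(maxlen=window)

    def update(self, total_loss):
        """Record one step's loss; return True if training should stop."""
        self.recent.append(float(total_loss))
        return (len(self.recent) == self.recent.maxlen
                and max(self.recent) < self.threshold)


if __name__ == '__main__':
    stopper = LossPlateauStopper(threshold=3.0, window=5)
    losses = [4.1, 3.5, 2.9, 2.8, 3.2, 2.7, 2.6, 2.5, 2.4, 2.3]
    for step, loss in enumerate(losses):
        if stopper.update(loss):
            print('Stopping at step', step)  # → Stopping at step 9
            break
```

Inside the patched loop it could replace the list logic: create `stopper = LossPlateauStopper(3.0, 5)` before the `while`, then set `should_stop = should_stop or stopper.update(total_loss)` after each step.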
If the loss stays below 3 for five consecutive steps, training stops. The downside is that the installed `tf_slim` package has to be modified. Also, the right threshold depends on the problem, so it has to be changed for every new object detection task. A better way would be to read the threshold loss value from a configuration file, but I'm stopping here for now.
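A minimal sketch of the configuration-file idea, assuming a small JSON file kept next to the pipeline config. The file (e.g. `early_stop.json`), its `loss_threshold` and `window` keys, and both helper functions are hypothetical; they are not part of the Object Detection API or tf_slim.

```python
import json
import os

# Hypothetical defaults: the values hard-coded in the patch (threshold 3, window 5).
DEFAULTS = {'loss_threshold': 3.0, 'window': 5}


def parse_early_stop_config(text):
    """Merge JSON settings from `text` over DEFAULTS and validate them."""
    config = dict(DEFAULTS)
    if text.strip():
        config.update(json.loads(text))
    config['loss_threshold'] = float(config['loss_threshold'])
    config['window'] = int(config['window'])
    if config['window'] < 1:
        raise ValueError('window must be at least 1')
    return config


def load_early_stop_config(path):
    """Read settings from `path` (a hypothetical JSON file); fall back to DEFAULTS."""
    if not path or not os.path.exists(path):
        return dict(DEFAULTS)
    with open(path) as f:
        return parse_early_stop_config(f.read())
```

The loaded `loss_threshold` and `window` would then replace the hard-coded 3 and 5 in the stopping check, so only the JSON file changes between problems.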
If anyone else has a better solution, please share.
I hope this helps someone. Thank you!