我正在评估 tensorflow 对象检测 API。我查看了在 Internet 上找到的几篇文章,我能够成功地训练模型并评估对象检测。
当我开始训练时,大约有 100 张图像被标记。我想在 TFrecord 中添加两个或三个新闻图像。添加这些图像后,我应该删除我的 model_output 目录并启动 train.py 还是可以在现有检查点之上执行此操作?
我正在评估 tensorflow 对象检测 API。我查看了在 Internet 上找到的几篇文章,我能够成功地训练模型并评估对象检测。
当我开始训练时,大约有 100 张图像被标记。我想在 TFrecord 中添加两个或三个新闻图像。添加这些图像后,我应该删除我的 model_output 目录并启动 train.py 还是可以在现有检查点之上执行此操作?
您可以生成一个train.record
包含希望训练的新图像的新文件,并且可以使用之前保存的检查点来恢复训练。您需要做的就是将 的 更改input_path
为tf_record_input_reader
指向新train.record
文件,并将fine_tune_checkpoint
更改为类似的东西,<path>/model.ckpt-XXX
而不是您第一次开始训练时使用的原始检查点。
希望你觉得这很有帮助。如果您遇到任何问题,请告诉我
从官方文档可以看到,Tensorflow 支持从检查点恢复状态:
https://www.tensorflow.org/programmers_guide/saved_model
# Create some variables.
v1 = tf.get_variable("v1", shape=[3])
v2 = tf.get_variable("v2", shape=[5])
# Add ops to save and restore all the variables.
saver = tf.train.Saver()
# Later, launch the model, use the saver to restore variables from disk, and
# do some work with the model.
with tf.Session() as sess:
# Restore variables from disk.
saver.restore(sess, "/tmp/model.ckpt")
print("Model restored.")
# Check the values of the variables
print("v1 : %s" % v1.eval())
print("v2 : %s" % v2.eval())
如果您不添加或删除类,则不需要重新启动模型。但是请注意在您之前的训练期间学习率是否发生了变化。如果您的调度程序大幅降低学习率,您可能会因为学习率太小而无法学习新图像。