3

我正在尝试修改使用 TensorFlow (v1.10) 中的 Estimator 类的程序,并且我想在每次评估时访问评估指标结果,以便仅在达到新的最大值时才能复制检查点文件.

我的一个想法是创建一个继承自 的类SessionRunHook,在方法中完成我想要的工作after_run。根据文档,我可以指定传递给after_runusing的内容before_run。但是,我找不到从传入的信息中访问我想要的评估指标结果的方法before_run

我查看了Estimator代码,它似乎正在将结果写入摘要文件,所以我的另一个想法是在after_run方法中读回它,但摘要 api似乎没有提供任何读取操作。

还有其他方法可以实现我想做的吗?不使用Estimator该类不是一种选择,因为这将涉及对我正在使用的代码进行重大更改。

4

1 回答 1

2

检查点与导出不同。检查点与故障恢复有关,涉及保存完整的训练状态(权重、全局步数等)。

在您的情况下,我建议您导出。导出的模型将写入一个名为“exporter”的目录,服务输入函数指定最终用户将向预测服务提供什么。

您可以使用“Best Exporter”类来导出性能最佳的模型:

https://www.tensorflow.org/api_docs/python/tf/estimator/BestExporter

此类导出最佳模型的服务图和检查点。

此外,每当新模型优于任何现有模型时,它都会执行模型导出。

于 2018-10-14T22:21:37.957 回答