0

我想在培训和服务中使用集线器,但我有点困惑如何在同一张图上做到这一点。即我有类似的东西

def build_graph(..., mode, ...):
    tags_and_args= ... # one for training, one for serving
    if mode == 'training':
        hub.create_module_spec(module_fn, tags_and_args=tags_and_args)
        module_output = hub.Module(...)
        hub.register_module_for_export(module_fn, tags_and_args=tags_and_args)

        loss, output = ...

    else:
        module_output = hub.Module(XXX)

我应该从磁盘重新加载模块吗?因此XXX将是我之前保存它的路径。或者它是否以某种方式保存为内存中的图形对象?

我将我的代码称为

estimator.train(...)
exporter = hub.LatestModuleExporter(...)
exporter.export(...)
esimator.export_savedmodel(...)  # for serving
4

1 回答 1

1

您可以在 Estimator 的 model_fn 中使用 hub.Module,而无需导出它。在 Estimator.train() 开始时,模块的变量将从它们的预训练值初始化(很像其他变量是随机初始化的)。之后,模块变量的行为与模型的其他变量非常相似——它们是模型检查点的一部分,并从那里恢复以进行评估、恢复训练或导出到 SavedModel 以供服务,就像任何其他变量一样。

仅当您想要创建可用于另一个单独的 Estimator 的模块的新版本(使用从训练中更新的权重)时才需要导出 hub.Module。

于 2018-08-28T10:28:40.360 回答