1

TensorFlow Estimator 易于使用参数服务器策略进行分布式训练。但我无法使用参数服务器策略进行预测。我找不到任何资源来介绍这部分。

预测示例代码:

    run_config = tf.estimator.RunConfig()
    model = tf.estimator.Estimator(
        model_fn=self.model_fn,
        model_dir=self._config.model_path,
        config=run_config,
        params=self.params())
    results = model.predict(
        input_fn=lambda: test_data.build(
            batch_size=self._config.eval_batch_size,
            num_epochs=1))

TF_CONFIG:

{'task': {'index': '0', 'type': 'ps'}, 'cluster': {'chief': ['127.0.0.1:2320'], 'ps': ['127.0.0.1:2220', '127.0.0.1:2221']}}
{'task': {'index': '1', 'type': 'ps'}, 'cluster': {'chief': ['127.0.0.1:2320'], 'ps': ['127.0.0.1:2220', '127.0.0.1:2221']}}
{'task': {'index': '0', 'type': 'chief'}, 'cluster': {'chief': ['127.0.0.1:2320'], 'ps': ['127.0.0.1:2220', '127.0.0.1:2221']}}

结果:PS和Woker都做了预测。

有什么建议吗?非常感谢。

4

1 回答 1

0

Estimator predict中,每个 ps 和 worker 用于MonitoredSession启动一个从现有检查点恢复的节点。为了进行分布式预测,您可以参考Estimator training

  • 开始ps
  • run_worker创建MonitoredTrainingSession而不是MonitoredSession
    • 记得启动工作服务器。
  • estimator.predict接收一个pathfor checkpoint,MonitoredTrainingSession接收一个directoryfor checkpoint。

您可以成功启动所有服务器和分布式预测。但是会有警告,比如全球步数没有增加。

Github上的详细代码

于 2020-08-17T08:35:03.893 回答