162

tf.app.run()Tensorflow 翻译演示如何工作?

tensorflow/models/rnn/translate/translate.py中,有对 的调用tf.app.run()。它是如何处理的?

if __name__ == "__main__":
    tf.app.run() 
4

6 回答 6

149
if __name__ == "__main__":

表示当前文件在 shell 下执行,而不是作为模块导入。

tf.app.run()

正如你可以通过文件看到的app.py

def run(main=None, argv=None):
  """Runs the program with an optional 'main' function and 'argv' list."""
  f = flags.FLAGS

  # Extract the args from the optional `argv` list.
  args = argv[1:] if argv else None

  # Parse the known flags from that list, or from the command
  # line otherwise.
  # pylint: disable=protected-access
  flags_passthrough = f._parse_flags(args=args)
  # pylint: enable=protected-access

  main = main or sys.modules['__main__'].main

  # Call the main function, passing through any arguments
  # to the final program.
  sys.exit(main(sys.argv[:1] + flags_passthrough))

让我们逐行分解:

flags_passthrough = f._parse_flags(args=args)

这确保了你通过命令行传递的参数是有效的,例如 python my_model.py --data_dir='...' --max_iteration=10000,这个特性实际上是基于 python 标准argparse模块实现的。

main = main or sys.modules['__main__'].main

main右边的第=一个是当前函数的第一个参数run(main=None, argv=None) 。Whilesys.modules['__main__']表示当前正在运行的文件(例如my_model.py)。

所以有两种情况:

  1. 你没有main函数my_model.py然后你必须打电话tf.app.run(my_main_running_function)

  2. 你有一个main函数my_model.py。(大多数情况下都是如此。)

最后一行:

sys.exit(main(sys.argv[:1] + flags_passthrough))

确保使用解析的参数正确调用您的main(argv)或函数。my_main_running_function(argv)

于 2016-11-23T13:59:49.353 回答
77

它只是一个非常快速的包装器,可以处理标志解析,然后分派到您自己的主程序。见代码

于 2015-11-14T03:29:18.993 回答
9

没有什么特别的tf.app。这只是一个通用的入口点脚本,它

使用可选的“main”函数和“argv”列表运行程序。

它与神经网络无关,它只是调用主函数,将任何参数传递给它。

于 2017-04-29T04:52:55.777 回答
6

简单来说,工作tf.app.run()首先设置全局标志以供以后使用,例如:

from tensorflow.python.platform import flags
f = flags.FLAGS

然后使用一组参数运行您的自定义主函数。

例如,在TensorFlow NMT代码库中,训练/推理程序执行的第一个入口点从此时开始(参见下面的代码)

if __name__ == "__main__":
  nmt_parser = argparse.ArgumentParser()
  add_arguments(nmt_parser)
  FLAGS, unparsed = nmt_parser.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

使用 解析参数后argparsetf.app.run()运行定义如下的函数“main”:

def main(unused_argv):
  default_hparams = create_hparams(FLAGS)
  train_fn = train.train
  inference_fn = inference.inference
  run_main(FLAGS, default_hparams, train_fn, inference_fn)

因此,在设置全局使用的标志之后,tf.app.run()只需运行作为参数main传递给它argv的函数。

PS:正如萨尔瓦多·达利(Salvador Dali)的回答所说,我想这只是一个很好的软件工程实践,尽管我不确定 TensorFlow 是否执行了main比使用普通 CPython 运行的函数的任何优化运行。

于 2017-10-06T19:53:01.247 回答
3

Google 代码很大程度上依赖于在库/二进制文件/python 脚本中访问的全局标志,因此 tf.app.run() 解析出这些标志以在 FLAGs(或类似的东西)变量中创建一个全局状态,然后调用 python main( ) 正如它应该。

如果他们没有对 tf.app.run() 进行此调用,那么用户可能会忘记进行 FLAG 解析,从而导致这些库/二进制文件/脚本无法访问他们需要的 FLAG。

于 2018-12-02T22:11:19.477 回答
1

2.0兼容答案:如果要使用tf.app.run()in Tensorflow 2.0,我们应该使用命令,

tf.compat.v1.app.run()或者您可以使用tf_upgrade_v21.x代码转换为2.0.

于 2020-01-21T09:44:06.007 回答