5

根据Google 的 Derek Chow 最近在 Google Cloud Big Data And Machine Learning 博客上的帖子,我使用 Cloud Machine Learning Engine 训练了一个对象检测器,现在我想使用 Cloud Machine Learning Engine 进行预测。

说明包括将 Tensorflow 图导出为 output_inference_graph.pb 的代码,但不包括如何将 protobuf 格式 (pb) 转换为 gcloud ml-engine predict 所需的 SavedModel 格式。

我查看了Google 的 @rhaertel80关于如何转换“Tensorflow For Poets”图像分类模型的答案以及 Google 的 @MarkMcDonald 提供的关于如何转换“Tensorflow For Poets 2”图像分类模型的答案,但似乎都不适用于博客文章中描述的对象检测器图(pb)。

请问如何转换该对象检测器图(pb)以便可以使用它或 gcloud ml-engine 预测?

4

2 回答 2

2

SavedModel在其结构中包含一个MetaGraphDef。要在 python 中从 GraphDef 创建 SavedModel,您可能需要使用链接中描述的构建器。

export_dir = ...
...
builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
with tf.Session(graph=tf.Graph()) as sess:
  ...
  builder.add_meta_graph_and_variables(sess,
                                       [tag_constants.TRAINING],
                                       signature_def_map=foo_signatures,
                                       assets_collection=foo_assets)
...
with tf.Session(graph=tf.Graph()) as sess:
  ...
  builder.add_meta_graph(["bar-tag", "baz-tag"])
...
builder.save()
于 2017-06-21T00:28:02.037 回答
1

这篇文章救了我!希望对来这里的人有所帮助。我使用导出成功的方法https://stackoverflow.com/a/48102615/6124383

https://github.com/tensorflow/tensorflow/pull/15855/commits/81ec5d20935352d71ff56fac06c36d6ff0a7ae05

def export_model(sess, architecture, saved_model_dir):
  if architecture == 'inception_v3':
    input_tensor = 'DecodeJpeg/contents:0'
  elif architecture.startswith('mobilenet_'):
    input_tensor = 'input:0'
  else:
    raise ValueError('Unknown architecture', architecture)
  in_image = sess.graph.get_tensor_by_name(input_tensor)
  inputs = {'image': tf.saved_model.utils.build_tensor_info(in_image)}
   out_classes = sess.graph.get_tensor_by_name('final_result:0')
  outputs = {'prediction': tf.saved_model.utils.build_tensor_info(out_classes)}
   signature = tf.saved_model.signature_def_utils.build_signature_def(
    inputs=inputs,
    outputs=outputs,
    method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME
  )
   legacy_init_op = tf.group(tf.tables_initializer(), name='legacy_init_op')
   # Save out the SavedModel.
  builder = tf.saved_model.builder.SavedModelBuilder(saved_model_dir)
  builder.add_meta_graph_and_variables(
    sess, [tf.saved_model.tag_constants.SERVING],
    signature_def_map={
      tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature
    },
    legacy_init_op=legacy_init_op)
  builder.save()

#execute this in the final of def main(_):
export_model(sess, FLAGS.architecture, FLAGS.saved_model_dir)

parser.add_argument(
      '--saved_model_dir',
      type=str,
      default='/tmp/saved_models/1/',
      help='Where to save the exported graph.'
  )
于 2018-08-31T10:13:10.913 回答