
Goal: TFX -> TF Lite Converter -> deploy the model on mobile/IoT devices

I'm currently learning TensorFlow Extended with its Chicago Taxi Pipeline Example. The pipeline has finished running (though it took a lot of struggling), and the Pusher component has emitted a TensorFlow SavedModel file (.pb).

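However, I've hit a new problem. With TensorFlow nightly and 1.13.1 (I tried both) and Python 2.7.6, I can generate, save, and load a SavedModel (an MNIST digits model I use to test the utilities) with simple Python code such as saved_model.simple_save and saved_model.loader.load. But when I load the model emitted by the TFX Pusher, I keep getting the error shown below.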

(Maybe I did something wrong in the TFX pipeline?)

The code I used:

import tensorflow as tf
with tf.Session(graph=tf.Graph()) as sess:
    tf.compat.v1.saved_model.loader.load(sess, ["serve"], "/home/tigerpaws/taxi/serving_model/taxi_simple/1553187887")#"/home/tigerpaws/saved_model_example/model")
    graph=tf.get_default_graph()

Error:

KeyError                                  Traceback (most recent call last)
<ipython-input-11-a6978b82c3d2> in <module>()
      1 with tf.Session(graph=tf.Graph()) as sess:
----> 2     tf.compat.v1.saved_model.loader.load(sess, ["serve"], "/home/tigerpaws/taxi/serving_model/taxi_simple/1553187887")#"/home/tigerpaws/saved_model_example/model")
      3     graph=tf.get_default_graph()

/home/tigerpaws/anaconda/lib/python2.7/site-packages/tensorflow/python/util/deprecation.pyc in new_func(*args, **kwargs)
    322               'in a future version' if date is None else ('after %s' % date),
    323               instructions)
--> 324       return func(*args, **kwargs)
    325     return tf_decorator.make_decorator(
    326         func, new_func, 'deprecated',

/home/tigerpaws/anaconda/lib/python2.7/site-packages/tensorflow/python/saved_model/loader_impl.pyc in load(sess, tags, export_dir, import_scope, **saver_kwargs)
    267   """
    268   loader = SavedModelLoader(export_dir)
--> 269   return loader.load(sess, tags, import_scope, **saver_kwargs)
    270 
    271 

/home/tigerpaws/anaconda/lib/python2.7/site-packages/tensorflow/python/saved_model/loader_impl.pyc in load(self, sess, tags, import_scope, **saver_kwargs)
    418     with sess.graph.as_default():
    419       saver, _ = self.load_graph(sess.graph, tags, import_scope,
--> 420                                  **saver_kwargs)
    421       self.restore_variables(sess, saver, import_scope)
    422       self.run_init_ops(sess, tags, import_scope)

/home/tigerpaws/anaconda/lib/python2.7/site-packages/tensorflow/python/saved_model/loader_impl.pyc in load_graph(self, graph, tags, import_scope, **saver_kwargs)
    348     with graph.as_default():
    349       return tf_saver._import_meta_graph_with_return_elements(  # pylint: disable=protected-access
--> 350           meta_graph_def, import_scope=import_scope, **saver_kwargs)
    351 
    352   def restore_variables(self, sess, saver, import_scope=None):

/home/tigerpaws/anaconda/lib/python2.7/site-packages/tensorflow/python/training/saver.pyc in _import_meta_graph_with_return_elements(meta_graph_or_file, clear_devices, import_scope, return_elements, **kwargs)
   1455           import_scope=import_scope,
   1456           return_elements=return_elements,
-> 1457           **kwargs))
   1458 
   1459   saver = _create_saver_from_imported_meta_graph(

/home/tigerpaws/anaconda/lib/python2.7/site-packages/tensorflow/python/framework/meta_graph.pyc in import_scoped_meta_graph_with_return_elements(meta_graph_or_file, clear_devices, graph, import_scope, input_map, unbound_inputs_col_name, restore_collections_predicate, return_elements)
    804         input_map=input_map,
    805         producer_op_list=producer_op_list,
--> 806         return_elements=return_elements)
    807 
    808     # Restores all the other collections.

/home/tigerpaws/anaconda/lib/python2.7/site-packages/tensorflow/python/util/deprecation.pyc in new_func(*args, **kwargs)
    505                 'in a future version' if date is None else ('after %s' % date),
    506                 instructions)
--> 507       return func(*args, **kwargs)
    508 
    509     doc = _add_deprecated_arg_notice_to_docstring(

/home/tigerpaws/anaconda/lib/python2.7/site-packages/tensorflow/python/framework/importer.pyc in import_graph_def(graph_def, input_map, return_elements, name, op_dict, producer_op_list)
    397   if producer_op_list is not None:
    398     # TODO(skyewm): make a copy of graph_def so we're not mutating the argument?
--> 399     _RemoveDefaultAttrs(op_dict, producer_op_list, graph_def)
    400 
    401   graph = ops.get_default_graph()

/home/tigerpaws/anaconda/lib/python2.7/site-packages/tensorflow/python/framework/importer.pyc in _RemoveDefaultAttrs(op_dict, producer_op_list, graph_def)
    157     # Remove any default attr values that aren't in op_def.
    158     if node.op in producer_op_dict:
--> 159       op_def = op_dict[node.op]
    160       producer_op_def = producer_op_dict[node.op]
    161       # We make a copy of node.attr to iterate through since we may modify

KeyError: u'BucketizeWithInputBoundaries'

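I also tried another approach: converting the SavedModel to a GraphDef (frozen graph) so that I could try the converter again. The conversion requires output_node_names, which I don't know, and I couldn't find where in the code the model is saved (where I might have spotted the output node names).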
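One way to find candidate output node names without importing the graph (and therefore without hitting the KeyError) is to read the serving signatures straight from saved_model.pb as a protobuf. The sketch below assumes a binary saved_model.pb and the "serve" tag used in the code above; the helper names are hypothetical, and whether the signature outputs are the right output_node_names for freezing and TF Lite conversion depends on the model. The saved_model_cli tool that ships with TensorFlow prints the same signature information via `saved_model_cli show --dir <export_dir> --all`.

```python
import os


def tensor_to_node_name(tensor_name):
    """Drop the ':<index>' suffix of a tensor name to get its node name."""
    # e.g. "head/predictions/probabilities:0" -> "head/predictions/probabilities"
    return tensor_name.split(":")[0]


def signature_output_nodes(export_dir, tag="serve"):
    """Return {signature_key: [output node names]} for MetaGraphs tagged `tag`.

    Only parses saved_model.pb; no graph is imported, so ops that are not
    registered in this process do not cause errors here.
    """
    # Lazy import so tensor_to_node_name works without TensorFlow installed.
    from tensorflow.core.protobuf import saved_model_pb2

    saved_model = saved_model_pb2.SavedModel()
    with open(os.path.join(export_dir, "saved_model.pb"), "rb") as f:
        saved_model.ParseFromString(f.read())

    outputs = {}
    for meta_graph in saved_model.meta_graphs:
        if tag not in meta_graph.meta_info_def.tags:
            continue
        for key, signature in meta_graph.signature_def.items():
            # Dense TensorInfo entries carry `name`; sparse ones use
            # `coo_sparse` and leave `name` empty, so they are skipped.
            outputs[key] = sorted(
                tensor_to_node_name(info.name)
                for info in signature.outputs.values()
                if info.name
            )
    return outputs
```

For example, `signature_output_nodes("/home/tigerpaws/taxi/serving_model/taxi_simple/1553187887")` would return one list of node names per signature key found under the "serve" tag.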

Any ideas about this problem, or alternative approaches? Thanks in advance.

Edit: Could someone help create a tag? I don't have 1500 reputation yet, but this question is really about tfx / tensorflow-extended.


1 Answer


Sorry for the confusion; the problem is actually in reading the SavedModel file.

The SavedModel contains an op, BucketizeWithInputBoundaries, that is not defined in op_dict.
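The traceback shows where this bites: `_RemoveDefaultAttrs` looks up `op_dict[node.op]` for each node whose op is listed in the SavedModel, and `op_dict` by default holds only the ops registered in the running Python process. A minimal diagnostic sketch, assuming a binary saved_model.pb: it lists the op types the graph uses that the current process does not know. The helper names are hypothetical, and `op_def_registry` is a private TensorFlow module whose API changed between releases (`get_registered_ops()` in TF 1.x, `get(name)` in later versions).

```python
import os


def missing_ops(op_types, is_registered):
    """Return the sorted op types for which is_registered(op_type) is False."""
    return sorted(t for t in set(op_types) if not is_registered(t))


def saved_model_op_types(export_dir, tag="serve"):
    """Collect the op types used by the MetaGraph(s) tagged `tag`.

    Only parses saved_model.pb, so it works even when loading fails.
    """
    # Lazy import: missing_ops above does not need TensorFlow.
    from tensorflow.core.protobuf import saved_model_pb2

    saved_model = saved_model_pb2.SavedModel()
    with open(os.path.join(export_dir, "saved_model.pb"), "rb") as f:
        saved_model.ParseFromString(f.read())

    op_types = set()
    for meta_graph in saved_model.meta_graphs:
        if tag not in meta_graph.meta_info_def.tags:
            continue
        library = meta_graph.graph_def.library
        op_types.update(node.op for node in meta_graph.graph_def.node)
        for func in library.function:
            op_types.update(node.op for node in func.node_def)
        # Calls to library functions also appear as node.op, but they are
        # not registered ops, so remove them from the result.
        op_types -= {func.signature.name for func in library.function}
    return op_types


def is_registered(op_type):
    """True if the running TensorFlow process knows `op_type` (private API)."""
    from tensorflow.python.framework import op_def_registry
    if hasattr(op_def_registry, "get"):  # newer TensorFlow releases
        return op_def_registry.get(op_type) is not None
    return op_type in op_def_registry.get_registered_ops()  # TF 1.x
```

Given the traceback, calling `missing_ops(saved_model_op_types(export_dir), is_registered)` in a fresh process would be expected to report BucketizeWithInputBoundaries for the TFX taxi model.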

The missing BucketizeWithInputBoundaries op is still on Google's TODO list, noted in a comment in two of their scripts, here and here (GitHub links):

# TODO(jyzhao): BucketizeWithInputBoundaries error without this.

Importing the module those scripts import, before loading the SavedModel, solved the problem:

from tensorflow.contrib.boosted_trees.python.ops import quantile_ops  # pylint: disable=unused-import
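Putting the fix together with the question's loading code gives roughly the following sketch. It assumes TensorFlow 1.13 with `tf.contrib` available (contrib does not exist in TensorFlow 2.x), and the function name is hypothetical. The import has to run in the same Python process and before the loader imports the graph, because that import is when the op lookup in the traceback happens.

```python
def load_serving_signature_keys(export_dir):
    """Load a TFX/tf.Transform SavedModel (TF 1.x) and return its signature keys."""
    import tensorflow as tf
    # According to the answer, importing quantile_ops makes
    # BucketizeWithInputBoundaries known to this process; it must run
    # before tf.saved_model.loader.load imports the graph.
    from tensorflow.contrib.boosted_trees.python.ops import quantile_ops  # pylint: disable=unused-import

    with tf.Session(graph=tf.Graph()) as sess:
        meta_graph_def = tf.saved_model.loader.load(
            sess, [tf.saved_model.tag_constants.SERVING], export_dir)
        return sorted(meta_graph_def.signature_def.keys())
```

Usage, with the export directory from the question: `load_serving_signature_keys("/home/tigerpaws/taxi/serving_model/taxi_simple/1553187887")`.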
answered 2019-05-25T10:09:23.870