
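I'm trying to load the .pb file generated by the Pusher component of a TFX pipeline. I'm using the function below to load it, but it fails with the following error. Please help.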

Error:

<ipython-input-40-af7ef7ac8a8b> in load_model()
      2     with tf.compat.v2.io.gfile.GFile('/home//saved_model.pb', "rb") as f:
      3         graph_def = tf.compat.v1.GraphDef()
----> 4         graph_def.ParseFromString(f.read())
      5 
      6     with tf.Graph().as_default() as graph:
DecodeError: Error parsing message

Function:

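import tensorflow as tf

def load_model():
    # Read the file and parse it as a frozen GraphDef
    with tf.compat.v2.io.gfile.GFile('/home/saved_model.pb', "rb") as f:
        graph_def = tf.compat.v1.GraphDef()
        # This is the call that raises DecodeError
        graph_def.ParseFromString(f.read())

    # Import the parsed GraphDef into a fresh graph
    with tf.Graph().as_default() as graph:
        tf.compat.v1.import_graph_def(graph_def, name="")
    return graph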



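Hey, you can try this code to load the .pb file that TensorFlow exported. A saved_model.pb file contains a serialized SavedModel protocol buffer, not a bare GraphDef, which is why parsing it directly as a GraphDef fails. Parse it as a SavedModel first, then take the GraphDef out of one of its meta graphs: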

import tensorflow as tf
from tensorflow.core.protobuf import saved_model_pb2

model_filename = 'saved_model.pb'

# saved_model.pb holds a SavedModel proto, so parse it with that message type
with tf.io.gfile.GFile(model_filename, 'rb') as f:
    sm = saved_model_pb2.SavedModel()
    sm.ParseFromString(f.read())

# Each MetaGraphDef carries its own GraphDef; import the first one
with tf.Graph().as_default() as graph:
    tf.compat.v1.import_graph_def(sm.meta_graphs[0].graph_def, name="")
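If you are on TensorFlow 2.x (which current TFX targets), you usually don't need to parse the .pb by hand at all: you can pass the whole export directory, the one containing saved_model.pb and the variables/ folder, to tf.saved_model.load. The sketch below assumes TF 2.x. The Adder module and the temporary directory are hypothetical, created only so the example runs end to end; in practice you would point export_dir at the directory your Pusher wrote.

```python
import tempfile

import tensorflow as tf


class Adder(tf.Module):
    """Hypothetical toy model, exported only to make the example self-contained."""

    @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32)])
    def __call__(self, x):
        return x + 1.0


# Hypothetical export location; replace with your Pusher's output directory
export_dir = tempfile.mkdtemp()
tf.saved_model.save(Adder(), export_dir)

# Load the whole SavedModel directory (saved_model.pb + variables/), not just the .pb
loaded = tf.saved_model.load(export_dir)
result = loaded(tf.constant([1.0, 2.0]))
print(result.numpy())
```

For a model exported by a TFX Trainer/Pusher, the callable you want is often exposed through `loaded.signatures` (commonly under a `serving_default` key) rather than as `loaded(...)`, so check `list(loaded.signatures.keys())` first.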
answered 2020-09-19 21:27:27