
I'm trying to load a frozen model into TensorFlow 2.7 and I get the following error:

  File "Object_Detector.py", line 53, in __init__
    tf.import_graph_def(od_graph_def, name='')
  File "C:\Python\lib\site-packages\tensorflow\python\util\deprecation.py", line 549, in new_func
    return func(*args, **kwargs)
  File "C:\Python\lib\site-packages\tensorflow\python\framework\importer.py", line 405, in import_graph_def
    producer_op_list=producer_op_list)
  File "C:\Python\lib\site-packages\tensorflow\python\framework\importer.py", line 464, in _import_graph_def_internal
    graph_def = _ProcessGraphDefParam(graph_def)
  File "C:\Python\lib\site-packages\tensorflow\python\framework\importer.py", line 96, in _ProcessGraphDefParam
    raise TypeError('graph_def must be a GraphDef proto.')
TypeError: graph_def must be a GraphDef proto.

Code:

        detect_model_name = 'Pretrained_Model_SSD'
        PATH_TO_CKPT = detect_model_name + '/saved_model.pb'

        self.detection_graph = tf.compat.v1.Graph()

        # configuration for possible GPU use
        # config = tf.config.experimental
        # config.config.gpu_options.allow_growth = True
        # load frozen tensorflow detection model and initialize 
        # the tensorflow graph
        with self.detection_graph.as_default():
            od_graph_def = tf.compat.v1.GraphDef()
            with tf.compat.v2.io.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
                od_graph_def = tf.compat.v1.get_default_graph()
                tf.import_graph_def(od_graph_def, name='')

            self.sess = tf.compat.v1.Session(graph=self.detection_graph)
            self.image_tensor = self.detection_graph.get_tensor_by_name('image_tensor:0')
            # Each box represents a part of the image where a particular object was detected.
            self.boxes = self.detection_graph.get_tensor_by_name('detection_boxes:0')
            # Each score represents the level of confidence for each of the objects.
            # The score is shown on the result image, together with the class label.
            self.scores = self.detection_graph.get_tensor_by_name('detection_scores:0')
            self.classes = self.detection_graph.get_tensor_by_name('detection_classes:0')
            self.num_detections = self.detection_graph.get_tensor_by_name('num_detections:0')

How can I load this model? (It was trained in TensorFlow 1.15.)
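For reference, the usual TF1-style pattern for a frozen graph reads the file's bytes into the `GraphDef` with `ParseFromString` and only then imports it. The code in the question instead reassigns `od_graph_def` to `tf.compat.v1.get_default_graph()`, which is a `Graph` object rather than a `GraphDef` proto, and that is consistent with the `TypeError` above. The sketch below assumes the `.pb` file really is a serialized frozen `GraphDef` (TF1 Object Detection API exports commonly name this file `frozen_inference_graph.pb`); the helper name `load_frozen_graph` is hypothetical.

```python
import tensorflow as tf


def load_frozen_graph(pb_path):
    """Load a TF1-style frozen GraphDef (.pb) into a fresh tf.Graph.

    Hypothetical helper: assumes pb_path is a serialized GraphDef,
    not a SavedModel.
    """
    graph_def = tf.compat.v1.GraphDef()
    with tf.io.gfile.GFile(pb_path, "rb") as fid:
        # Fill the GraphDef proto from the file's bytes instead of
        # replacing the variable with a tf.Graph object.
        graph_def.ParseFromString(fid.read())

    graph = tf.Graph()
    with graph.as_default():
        # name="" keeps the original tensor names, e.g. "image_tensor:0".
        tf.compat.v1.import_graph_def(graph_def, name="")
    return graph
```

The returned graph can then be passed to `tf.compat.v1.Session(graph=...)` and queried with `get_tensor_by_name('image_tensor:0')` exactly as in the question. If `saved_model.pb` is part of a SavedModel directory rather than a frozen graph, parsing it as a `GraphDef` will most likely fail; in that case `tf.compat.v1.saved_model.loader.load(sess, ["serve"], export_dir)` is the TF1-compatible loader for that format.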
