I am trying to load a frozen model in TensorFlow 2.7, and I get the following error:
  File "Object_Detector.py", line 53, in __init__
    tf.import_graph_def(od_graph_def, name='')
  File "C:\Python\lib\site-packages\tensorflow\python\util\deprecation.py", line 549, in new_func
    return func(*args, **kwargs)
  File "C:\Python\lib\site-packages\tensorflow\python\framework\importer.py", line 405, in import_graph_def
    producer_op_list=producer_op_list)
  File "C:\Python\lib\site-packages\tensorflow\python\framework\importer.py", line 464, in _import_graph_def_internal
    graph_def = _ProcessGraphDefParam(graph_def)
  File "C:\Python\lib\site-packages\tensorflow\python\framework\importer.py", line 96, in _ProcessGraphDefParam
    raise TypeError('graph_def must be a GraphDef proto.')
TypeError: graph_def must be a GraphDef proto.
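The check that fails is in `_ProcessGraphDefParam`: `import_graph_def` only accepts a `GraphDef` protocol buffer (the serializable description of a graph), not a `tf.Graph` Python object. A minimal sketch of the difference, using a toy one-op graph instead of the detection model; the helper `is_graph_def` is hypothetical and only there for illustration:

```python
import tensorflow as tf


def is_graph_def(obj):
    """Hypothetical helper: True if obj is a GraphDef proto, the type import_graph_def expects."""
    return isinstance(obj, tf.compat.v1.GraphDef)


# Build a tiny graph containing a single constant op named 'c'.
graph = tf.compat.v1.Graph()
with graph.as_default():
    tf.constant(1.0, name='c')

print(is_graph_def(graph))                 # False: a tf.Graph object
print(is_graph_def(graph.as_graph_def()))  # True: its GraphDef proto

target = tf.compat.v1.Graph()
with target.as_default():
    try:
        # Passing the tf.Graph itself is rejected with a TypeError.
        tf.compat.v1.import_graph_def(graph, name='')
    except TypeError as err:
        print('TypeError:', err)
    # The proto form is accepted and copies op 'c' into target.
    tf.compat.v1.import_graph_def(graph.as_graph_def(), name='')
```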
Code:
detect_model_name = 'Pretrained_Model_SSD'
PATH_TO_CKPT = detect_model_name + '/saved_model.pb'
self.detection_graph = tf.compat.v1.Graph()
# configuration for possible GPU use
# config = tf.config.experimental
# config.config.gpu_options.allow_growth = True
# load frozen tensorflow detection model and initialize
# the tensorflow graph
with self.detection_graph.as_default():
    od_graph_def = tf.compat.v1.GraphDef()
    with tf.compat.v2.io.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
        od_graph_def = tf.compat.v1.get_default_graph()
        tf.import_graph_def(od_graph_def, name='')
    self.sess = tf.compat.v1.Session(graph=self.detection_graph)
    self.image_tensor = self.detection_graph.get_tensor_by_name('image_tensor:0')
    # Each box represents a part of the image where a particular object was detected.
    self.boxes = self.detection_graph.get_tensor_by_name('detection_boxes:0')
    # Each score represents the level of confidence for each of the objects.
    # The score is shown on the result image, together with the class label.
    self.scores = self.detection_graph.get_tensor_by_name('detection_scores:0')
    self.classes = self.detection_graph.get_tensor_by_name('detection_classes:0')
    self.num_detections = self.detection_graph.get_tensor_by_name('num_detections:0')
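In the code above, `od_graph_def` is first created as a `GraphDef` but is then overwritten with `tf.compat.v1.get_default_graph()`, which returns a `tf.Graph`; that object is what reaches `tf.import_graph_def` and triggers the `TypeError`. If the `.pb` file is a frozen `GraphDef`, the usual pattern reads the file bytes and parses them into the proto. A minimal sketch, assuming the file really is a frozen graph (the TF 1.x Object Detection API conventionally names it `frozen_inference_graph.pb`); `load_frozen_graph` is a hypothetical helper name:

```python
import tensorflow as tf


def load_frozen_graph(pb_path):
    """Read a frozen GraphDef (.pb) file and import it into a fresh tf.Graph.

    Hypothetical helper; assumes pb_path is a frozen GraphDef, not a SavedModel.
    """
    graph = tf.compat.v1.Graph()
    with graph.as_default():
        graph_def = tf.compat.v1.GraphDef()
        with tf.io.gfile.GFile(pb_path, 'rb') as fid:
            # Deserialize the file bytes into the GraphDef proto, instead of
            # replacing graph_def with a tf.Graph object.
            graph_def.ParseFromString(fid.read())
        # name='' keeps the original tensor names, e.g. 'image_tensor:0'.
        tf.compat.v1.import_graph_def(graph_def, name='')
    return graph
```

Inside `__init__`, this would replace the `with` block with `self.detection_graph = load_frozen_graph(PATH_TO_CKPT)` followed by `self.sess = tf.compat.v1.Session(graph=self.detection_graph)`; the `get_tensor_by_name` calls can stay as they are.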
How can I load this model? (It was trained in TensorFlow 1.15.)
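The path in the code points at `saved_model.pb`, which is normally a `SavedModel` rather than a frozen `GraphDef`, so parsing it as a `GraphDef` is not expected to work. Assuming `Pretrained_Model_SSD` is a complete SavedModel directory (the `.pb` plus its `variables/` folder, if any), TF 2.x can load it directly with `tf.saved_model.load`, including SavedModels exported from TF 1.x. A minimal sketch; the helper name and the default signature key `serving_default` are assumptions, so print `infer.structured_input_signature` and `infer.structured_outputs` to find the model's real input and output names:

```python
import tensorflow as tf


def load_saved_model_signature(export_dir, signature_key='serving_default'):
    """Load a SavedModel directory and return (loaded_object, signature_fn).

    Hypothetical helper. export_dir is the directory that contains
    saved_model.pb, not the path to the .pb file itself.
    """
    loaded = tf.saved_model.load(export_dir)
    infer = loaded.signatures[signature_key]
    # Return `loaded` as well: the signature function can depend on
    # variables tracked by the loaded object, so keep it alive.
    return loaded, infer


# Assumption: Pretrained_Model_SSD is a complete SavedModel directory.
# model, infer = load_saved_model_signature('Pretrained_Model_SSD')
# print(infer.structured_input_signature)
# print(infer.structured_outputs)
```

`infer` is then called with keyword arguments matching the signature's input names and returns a dict of output tensors, so no `Session` or `get_tensor_by_name` calls are needed.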