1

我正在关注一个使用 tensorflow 的 1.15.0 对象检测 API 的示例。本教程在以下方面进行了明确的说明:

  • 如何下载模型
  • 如何使用 .xml 文件加载自定义数据库,从中制作 .cvs 文件,然后 .record 文件
  • 如何配置训练管道
  • 如何获得张量板图
  • 如何训练净节省检查点(使用 model_main.py)
  • 如何导出(保存)模型(使用 export_inference_graph.py)

但是,我无法完成的是加载保存的模型以使用它。我试过了tf.saved_model.loader.load(sess, flags, export_dir,但我得到了

INFO:tensorflow:Saver not created because there are no variables in the graph to restore.
INFO:tensorflow:The specified SavedModel has no variables; no checkpoints were restored.

中给出的文件夹export_dir具有以下结构:

+dir
   +saved_model
      -saved_model.pb
   -model.ckpt.data-00000-of-00001
   -model.ckpt.index
   -checkpoint
   -frozen_inference_graph.pb
   -model.ckpt.meta
   -pipeline.config

我的最终目标是用相机捕捉图像,并将它们输入网络以进行实时对象检测。\ 作为中间步骤,现在我只想能够输入单张图片并获得输出。我能够训练网络,但现在我无法使用它。

先感谢您。

4

1 回答 1

2

我找到了一个关于如何下载模型的示例,让我通过它。\ 由于示例中下载的文件的文件夹格式与我的代码相同,因此我只需要对其进行调整。

下载模型的 orifinal 函数是

def load_model(model_name):
  base_url = 'http://download.tensorflow.org/models/object_detection/'
  model_file = model_name + '.tar.gz'
  model_dir = tf.keras.utils.get_file(
    fname=model_name, 
    origin=base_url + model_file,
    untar=True)

  model_dir = pathlib.Path(model_dir)/"saved_model"

  model = tf.saved_model.load(str(model_dir))
  model = model.signatures['serving_default']

  return model

然后我用那个函数创建了这个新函数

def load_local_model(model_path):
  model_dir = pathlib.Path(model_path)/"saved_model"

  model = tf.saved_model.load(str(model_dir))
  model = model.signatures['serving_default']

  return model

起初这没有用,因为tf.saved_model.load预期有 3 个参数,但通过在同一个示例中导入两个导入块解决了这个问题,我仍然不知道导入是什么伎俩以及为什么(当我得到它时我会编辑这个答案),但目前这段代码可以工作,这个例子可以做更多的事情。

导入块如下

import numpy as np
import os
import six.moves.urllib as urllib
import sys
import tarfile
import tensorflow as tf
import zipfile

from collections import defaultdict
from io import StringIO
from matplotlib import pyplot as plt
from PIL import Image
from IPython.display import display

from object_detection.utils import ops as utils_ops
from object_detection.utils import label_map_util
from object_detection.utils import visualization_utils as vis_util

编辑 这个工作真正需要的是以下块。

import os
import pathlib


if "models" in pathlib.Path.cwd().parts:
  while "models" in pathlib.Path.cwd().parts:
    os.chdir('..')
elif not pathlib.Path('models').exists():
  !git clone --depth 1 https://github.com/tensorflow/models

%%bash
cd models/research/
protoc object_detection/protos/*.proto --python_out=.

%%bash 
cd models/research
pip install .

否则此导入块将不起作用

from object_detection.utils import ops as utils_ops
from object_detection.utils import label_map_util
from object_detection.utils import visualization_utils as vis_util
于 2020-04-23T01:08:29.710 回答