我找到了一个关于如何下载模型的示例,让我通过它。\ 由于示例中下载的文件的文件夹格式与我的代码相同,因此我只需要对其进行调整。
下载模型的 orifinal 函数是
def load_model(model_name):
base_url = 'http://download.tensorflow.org/models/object_detection/'
model_file = model_name + '.tar.gz'
model_dir = tf.keras.utils.get_file(
fname=model_name,
origin=base_url + model_file,
untar=True)
model_dir = pathlib.Path(model_dir)/"saved_model"
model = tf.saved_model.load(str(model_dir))
model = model.signatures['serving_default']
return model
然后我用那个函数创建了这个新函数
def load_local_model(model_path):
model_dir = pathlib.Path(model_path)/"saved_model"
model = tf.saved_model.load(str(model_dir))
model = model.signatures['serving_default']
return model
起初这没有用,因为tf.saved_model.load
预期有 3 个参数,但通过在同一个示例中导入两个导入块解决了这个问题,我仍然不知道导入是什么伎俩以及为什么(当我得到它时我会编辑这个答案),但目前这段代码可以工作,这个例子可以做更多的事情。
导入块如下
import numpy as np
import os
import six.moves.urllib as urllib
import sys
import tarfile
import tensorflow as tf
import zipfile
from collections import defaultdict
from io import StringIO
from matplotlib import pyplot as plt
from PIL import Image
from IPython.display import display
和
from object_detection.utils import ops as utils_ops
from object_detection.utils import label_map_util
from object_detection.utils import visualization_utils as vis_util
编辑
这个工作真正需要的是以下块。
import os
import pathlib
if "models" in pathlib.Path.cwd().parts:
while "models" in pathlib.Path.cwd().parts:
os.chdir('..')
elif not pathlib.Path('models').exists():
!git clone --depth 1 https://github.com/tensorflow/models
%%bash
cd models/research/
protoc object_detection/protos/*.proto --python_out=.
%%bash
cd models/research
pip install .
否则此导入块将不起作用
from object_detection.utils import ops as utils_ops
from object_detection.utils import label_map_util
from object_detection.utils import visualization_utils as vis_util