有很多在线教程可以指导如何使用预训练模型在 Keras 中对一张图像进行一次预测。对于我的情况,我在 Keras 中使用 VGG16 模型,我需要连续预测图像,所以我使用 for 循环加载图像,然后将其传递给预测函数,它运行良好,但是一个预测时间太长(我的 ~800ms机器,仅限 CPU),这里是代码:
# 一个完整的预测函数花费 800ms def 预测(图像): 图像 = img_to_array(图像) 图像 = 图像 / 255 图像 = np.expand_dims(图像,轴 = 0) # 搭建VGG16网络,这一行代码耗时400~500ms 模型 = keras.applications.VGG16(include_top=True, weights='imagenet') # 做预测 预测 = 模型.预测(图像) ''' 过程预测结果 ''' ''' 一些预处理 ''' 对于 imgs_list 中的 img: 预测(图片)
上面的代码可以运行良好,但是每次预测花费的时间太长,整个函数需要800ms,构建VGG网络需要500ms,成本太高。我想为连续预测模式的每个预测删除这 500 毫秒。
我尝试将“model = keras.applications.VGG16(include_top=True, weights='imagenet')”这一行代码放在预测函数的外部,全局定义它或将“model”作为参数传递给函数,但程序会返回错误并在第一次成功预测后结束。
回溯(最近一次通话最后): _run 中的文件“/home/zi/venv/lib/python3.5/site-packages/tensorflow/python/client/session.py”,第 1075 行 子馈送,allow_tensor=True,allow_operation=False) 文件“/home/zi/venv/lib/python3.5/site-packages/tensorflow/python/framework/ops.py”,第 3590 行,在 as_graph_element 返回 self._as_graph_element_locked(obj,allow_tensor,allow_operation) _as_graph_element_locked 中的文件“/home/zi/venv/lib/python3.5/site-packages/tensorflow/python/framework/ops.py”,第 3669 行 raise ValueError("张量 %s 不是该图的元素。" % obj) ValueError: Tensor Tensor("input_1:0", shape=(?, ?, ?, 3), dtype=float32) 不是该图的元素。 在处理上述异常的过程中,又出现了一个异常: 回溯(最近一次通话最后): 文件“multi_classifier.py”,第 256 行,在 预测(当前文件路径) 预测中的文件“multi_classifier.py”,第 184 行 瓶颈_预测=模型_1.预测(图像) 预测中的文件“/home/zi/venv/lib/python3.5/site-packages/keras/engine/training.py”,第 1835 行 详细=详细,步骤=步骤) _predict_loop 中的文件“/home/zi/venv/lib/python3.5/site-packages/keras/engine/training.py”,第 1331 行 batch_outs = f(ins_batch) __call__ 中的文件“/home/zi/venv/lib/python3.5/site-packages/keras/backend/tensorflow_backend.py”,第 2482 行 **self.session_kwargs) 文件“/home/zi/venv/lib/python3.5/site-packages/tensorflow/python/client/session.py”,第 900 行,运行中 run_metadata_ptr) _run 中的文件“/home/zi/venv/lib/python3.5/site-packages/tensorflow/python/client/session.py”,第 1078 行 '无法将 feed_dict 键解释为张量:' + e.args[0]) TypeError:无法将 feed_dict 键解释为张量:张量 Tensor("input_1:0", shape=(?, ?, ?, 3), dtype=float32) 不是此图的元素。
似乎我需要为每个预测实例化一个 VGG 模型,如何更改代码以节省模型构建时间?谢谢。