3

有很多在线教程可以指导如何使用预训练模型在 Keras 中对一张图像进行一次预测。对于我的情况,我在 Keras 中使用 VGG16 模型,我需要连续预测图像,所以我使用 for 循环加载图像,然后将其传递给预测函数,它运行良好,但是一个预测时间太长(我的 ~800ms机器,仅限 CPU),这里是代码:

    
    # 一个完整的预测函数花费 800ms
    def 预测(图像):
        图像 = img_to_array(图像)
        图像 = 图像 / 255
        图像 = np.expand_dims(图像,轴 = 0)
        # 搭建VGG16网络,这一行代码耗时400~500ms
        模型 = keras.applications.VGG16(include_top=True, weights='imagenet')
        # 做预测
        预测 = 模型.预测(图像)
        '''
        过程预测结果
        '''

    '''
    一些预处理
    '''
    对于 imgs_list 中的 img:
        预测(图片)

上面的代码可以运行良好,但是每次预测花费的时间太长,整个函数需要800ms,构建VGG网络需要500ms,成本太高。我想为连续预测模式的每个预测删除这 500 毫秒。

我尝试将“model = keras.applications.VGG16(include_top=True, weights='imagenet')”这一行代码放在预测函数的外部,全局定义它或将“model”作为参数传递给函数,但程序会返回错误并在第一次成功预测后结束。

回溯(最近一次通话最后):
  _run 中的文件“/home/zi/venv/lib/python3.5/site-packages/tensorflow/python/client/session.py”,第 1075 行
    子馈送,allow_tensor=True,allow_operation=False)
  文件“/home/zi/venv/lib/python3.5/site-packages/tensorflow/python/framework/ops.py”,第 3590 行,在 as_graph_element
    返回 self._as_graph_element_locked(obj,allow_tensor,allow_operation)
  _as_graph_element_locked 中的文件“/home/zi/venv/lib/python3.5/site-packages/tensorflow/python/framework/ops.py”,第 3669 行
    raise ValueError("张量 %s 不是该图的元素。" % obj)
ValueError: Tensor Tensor("input_1:0", shape=(?, ?, ?, 3), dtype=float32) 不是该图的元素。
在处理上述异常的过程中,又出现了一个异常:

回溯(最近一次通话最后):
  文件“multi_classifier.py”,第 256 行,在
    预测(当前文件路径)
  预测中的文件“multi_classifier.py”,第 184 行
    瓶颈_预测=模型_1.预测(图像)
  预测中的文件“/home/zi/venv/lib/python3.5/site-packages/keras/engine/training.py”,第 1835 行
    详细=详细,步骤=步骤)
  _predict_loop 中的文件“/home/zi/venv/lib/python3.5/site-packages/keras/engine/training.py”,第 1331 行
    batch_outs = f(ins_batch)
  __call__ 中的文件“/home/zi/venv/lib/python3.5/site-packages/keras/backend/tensorflow_backend.py”,第 2482 行
    **self.session_kwargs)
  文件“/home/zi/venv/lib/python3.5/site-packages/tensorflow/python/client/session.py”,第 900 行,运行中
    run_metadata_ptr)
  _run 中的文件“/home/zi/venv/lib/python3.5/site-packages/tensorflow/python/client/session.py”,第 1078 行
    '无法将 feed_dict 键解释为张量:' + e.args[0])
TypeError:无法将 feed_dict 键解释为张量:张量 Tensor("input_1:0", shape=(?, ?, ?, 3), dtype=float32) 不是此图的元素。


似乎我需要为每个预测实例化一个 VGG 模型,如何更改代码以节省模型构建时间?谢谢。

4

0 回答 0