2

我正在使用 keras 的预训练模型,但在尝试进行预测时出现了错误。我在烧瓶服务器中有以下代码:

from NeuralNetwork import *

@app.route("/uploadMultipleImages", methods=["POST"])
def uploadMultipleImages():
    uploaded_files = request.files.getlist("file[]")
    getPredictionfunction = preTrainedModel["VGG16"]

    for file in uploaded_files:
        path = os.path.join(STATIC_PATH, file.filename)
        result = getPredictionfunction(path)

这就是我在 NeuralNetwork.py 文件中的内容:

vgg16 = VGG16(weights='imagenet', include_top=True)
def getVGG16Prediction(img_path):

    model = vgg16
    img = image.load_img(img_path, target_size=(224, 224))
    x = image.img_to_array(img)
    x = np.expand_dims(x, axis=0)
    x = preprocess_input(x)

    pred = model.predict(x) #ERROR HERE
    return sort(decode_predictions(pred, top=3)[0])

preTrainedModel["VGG16"] = getVGG16Prediction

但是,在下面运行此代码不会产生任何错误:

if __name__ == "__main__":
    STATIC_PATH = os.getcwd()+"/static"
    print(preTrainedModel["VGG16"](STATIC_PATH+"/18.jpg"))

这是完整的错误: 在此处输入图像描述

非常感谢任何意见或建议。谢谢你。

4

2 回答 2

3

考虑到后端设置为 tensorflow。您应该将 Keras 会话设置为张量流图

from tensorflow import Graph, Session
from keras import backend 
model = 'model path'
graph1 = Graph()
with graph1.as_default():
    session1 = Session(graph=graph1)
    with session1.as_default():
        model_1.load_model(model) # depends on your model type

model2 = 'model path2'
graph2 = Graph()
with graph2.as_default():
    session2 = Session(graph=graph2)
    with session2.as_default():
        model_2.load_model(model2) # depends on your model type

并用于预测

K.set_session(session#)
with graph#.as_default():
    prediction = model_#.predict(img_data)
于 2018-07-25T23:26:36.307 回答
1

编辑:我在下面写的在部署应用程序时似乎不起作用(直到现在我只是在本地测试)。app.config 中的模型似乎加载得太频繁了。(应每个要求?)

巧合的是,我昨天也遇到了同样的问题。TensorFlow 和 Flask 的交互似乎有些问题。不幸的是,我对其中任何一个的内部知识都不够了解,无法真正理解这个问题,但我可以提供一个帮助我让它工作的技巧。(注意:我使用的是 Python3,但我认为这在这里没有什么不同。)

在烧瓶应用程序的全局命名空间中初始化模型时似乎会​​出现问题。因此我将模型直接加载到 app.config 中:

app.config.update({"MODEL":VGG16(weights='imagenet', include_top=True)})
# ...
app.config["MODEL"].predict(x)

也许您可以在 server.py 而不是 NeuralNetwork.py 中加载模型并将其getVGG16Predictionimg_path?

于 2017-02-08T10:50:39.233 回答