0

我已经训练了一个模型(通过 Keras 框架),将它导出,model.save('model.hdf5')现在我想将它与很棒的 Streamlit 集成。显然,我不想在最终用户每次插入新输入时都加载模型,而是一劳永逸地加载它。所以我的代码看起来像这样:

@st.cache
def load_my_model():
    model = load_model('model.hdf5')
    model.summary()

    return model

if __name__ == '__main__':
    st.title('My first app')
    sentence = st.text_input('Input your sentence here:')
    model = load_my_model()
    if sentence:
        y_hat = model.predict(sentence)

这样我得到了:

“streamlit.errors.UnhashableType:”

例外。@st.cache(allow_output_mutation=True)当我在 streamlit 页面上运行查询时,我尝试使用。我有:

“TypeError:无法将 feed_dict 键解释为 Tensor:Tensor Tensor("input_1:0", shape=(?, 80), dtype=int32) 不是该图的元素。”

(当然,没有任何缓存装饰器,模型已加载并且工作正常)

我应该如何正确加载和缓存Keras 训练的模型?

  • Python 版本:2.7(不幸的是)
  • Keras 版本:2.1.3
  • 张量流版本:1.3.0
  • 流光版:0.55.2

非常感谢!

4

1 回答 1

0

解决方案是:

  1. 添加_make_predict_function()通话
  2. 返回会话
from keras import backend as K

@st.cache(allow_output_mutation=True)
def load_model():
    model = load_model(MODEL_PATH)
    model._make_predict_function()
    model.summary()  # included to make it visible when model is reloaded
    session = K.get_session()
    return model, session

if __name__ == '__main__':
    st.title('My first app')
    sentence = st.text_input('Input your sentence here:')
    model, session = load_model()
    if sentence:
        K.set_session(session)
        y_hat = model.predict(sentence)
于 2020-04-11T12:53:48.277 回答