我想将我的 ML 模型保存到我的本地机器上。我已经按照https://stackoverflow.com/a/29291153/11543494这个答案将 ML 模型存储在本地机器中,但是当我从本地机器加载持久的 ML 模型时,我得到了关键错误。
我已经创建了可以预测 URL 类别的 ML 模型。现在我想将 ML 模型与我的 Web 应用程序集成,这就是我使用 Flask 创建 API 的原因。
我已经在 jupyter notebook 中测试了我的 ML 模型,现在我的代码与 ML 模型相关,我只想转储我的 ML 模型并在我的 API 中使用它。在 jupyter notebook 中,我得到了正确的输出,但是当我在我的 API 代码中加载持久文件时,我得到了 KeyError。我尝试使用pickle,joblib,但出现MemoryError,我也尝试解决该问题,但无法解决该问题,因此我正在尝试Klepto。
盗贼密码
from klepto.archives import dir_archive
model = dir_archive('E:/Mayur/Sem 5/Python project/model_klepto',{'result':gs_clf},serialized=True, cached=False)
#gs_clf = gs_clf.fit(x_train, y_train) #RandomizedSearchCV
model.dump()
API 代码
import numpy as np
from flask import Flask, request, jsonify, render_template
from klepto.archives import dir_archive
app = Flask(__name__)
demo = dir_archive(
'E:/Mayur/Sem 5/Python project/model_klepto', {}, serialized=True, cached=False)
demo.load()
@app.route('/')
def home():
return render_template('index.html')
@app.route('/predict', methods=['POST'])
def predict():
input = request.form.values()
final_feature = [np.array(input)]
prediction = demo['result'].predict([str(final_feature)])
return render_template('index.html', prediction_text=prediction)
if __name__ == "__main__":
app.run(debug=True)
当我运行 API 时,我得到 KeyError:'result'。
如果我在 jupyter notebook 中运行以下代码,我会得到正确的输出
demo = dir_archive(
'E:/Mayur/Sem 5/Python project/model_klepto', {}, serialized=True, cached=False)
demo.load()
demo
输出>
ir_archive('model_klepto', {'result': RandomizedSearchCV(cv='warn', error_score='raise-deprecating',
estimator=Pipeline(memory=None,
steps=[('vect',
CountVectorizer(analyzer='word',
binary=False,
decode_error='strict',
dtype=<class 'numpy.int64'>,
encoding='utf-8',
input='content',
lowercase=True,
max_df=1.0,
max_features=None,
min_df=1,
ngram_range=(1,
1),
preprocessor=None,
stop_words=None,
strip_accen...
sublinear_tf=False,
use_idf=True)),
('clf',
MultinomialNB(alpha=1.0,
class_prior=None,
fit_prior=True))],
verbose=False),
iid='warn', n_iter=5, n_jobs=None,
param_distributions={'clf__alpha': (0.01, 0.001),
'tfidf__use_idf': (True, False),
'vect__ngram_range': [(1, 1), (1, 2)]},
pre_dispatch='2*n_jobs', random_state=None, refit=True,
return_train_score=False, scoring=None, verbose=0)}, cached=False)
demo['result'].predict(['http://www.windows.com'])
输出> 数组(['计算机'],dtype =
这是堆栈跟踪的屏幕截图 Stack trace