我在 colab 上使用 simpletransformers 模型训练了变压器,下载了序列化模型,我在使用它进行推理方面几乎没有问题。在 jupyter 上的模型上加载模型是可行的,但是在将它与 fastapi 一起使用时会出现错误这就是我在 jupyter 上使用它的方式:
from scipy.special import softmax
label_cols = ['art', 'politics', 'health', 'tourism']
model = torch.load("model.bin")
pred = model.predict(['i love politics'])[1]
preds = softmax(pred,axis=1)
preds
它给出以下结果:array([[0.00230123, 0.97465035, 0.00475409, 0.01829433]])
我尝试按如下方式使用 fastapi,但不断出现错误:
from pydantic import BaseModel
class Message(BaseModel):
text : str
model = torch.load("model.bin")
@app.post("/predict")
def predict_health(data: Message):
prediction = model.predict(data.text)[1]
preds = softmax(prediction, axis=1)
return {"results": preds}