3

我使用 pycaret 作为我的 ML 工作流程,我尝试使用 FastAPI 创建一个 API。这是我第一次进入生产级别,所以我对 API 有点困惑

我有 10 个特征;年龄:float,live_province:str,live_city:str,live_area_big:str,live_area_small:str,sex:float,marital:float,bank:str,salary:float,amount:float和一个标签,其中包含二进制值(0和 1)。

这就是我构建 API 的脚本

from pydantic import BaseModel
import numpy as np
from pycaret.classification import *

import uvicorn
from fastapi import FastAPI

app = FastAPI()

model = load_model('catboost_cm_creditable')

class Data(BaseModel):
    age: float
    live_province: str
    live_city: str
    live_area_big: str
    live_area_small: str
    sex: float
    marital: float
    bank: str
    salary: float
    amount: float

input_dict = Data

@app.post("/predict")
def predict(model, input_dict):
    predictions_df = predict_model(estimator=model, data=input_dict)
    predictions = predictions_df['Score'][0]
    return predictions

当我尝试运行uvicorn script:app并转到文档时,我找不到我的功能的参数,参数只显示模型和 input_dict 在此处输入图像描述

如何将我的功能带到 API 中的参数上?

4

2 回答 2

2

您需要键入提示您的 Pydantic 模型以使其与您的 FastAPI 一起使用

想象一下,当您需要记录该函数时,您真的在使用标准 Python,

def some_function(price: int) ->int:
    return price

Pydantic与上面的例子没有什么不同

class Data其实是一条@dataclass拥有超能力的蟒蛇(来自Pydantic)

from fastapi import Depends

class Data(BaseModel):
    age: float
    live_province: str
    live_city: str
    live_area_big: str
    live_area_small: str
    sex: float
    marital: float
    bank: str
    salary: float
    amount: float


@app.post("/predict")
def predict(data: Data = Depends()):
    predictions_df = predict_model(estimator=model, data=data)
    predictions = predictions_df["Score"][0]
    return predictions

有一个小技巧,使用Depends,您将获得一个查询,就像您分别定义每个字段时一样。

取决于

在此处输入图像描述

不依赖

在此处输入图像描述

于 2020-08-31T15:20:50.103 回答
1

您的问题在于 API 函数的定义。您为数据输入添加了一个参数,但您没有告诉 FastAPI 它的类型。另外我假设您的意思是不要使用全局加载的模型,而不是将其作为参数接收。此外,您不需要为输入数据创建全局实例,因为您想从用户那里获取它。

因此,只需将函数的签名更改为:

def predict(input_dict: Data):

并删除该行:

input_dict = Data

(这只是为您的班级创建了一个别名Data,名为input_dict

你最终会得到:

app = FastAPI()

model = load_model('catboost_cm_creditable')

class Data(BaseModel):
    age: float
    live_province: str
    live_city: str
    live_area_big: str
    live_area_small: str
    sex: float
    marital: float
    bank: str
    salary: float
    amount: float

@app.post("/predict")
def predict(input_dict: Data):
    predictions_df = predict_model(estimator=model, data=input_dict)
    predictions = predictions_df['Score'][0]
    return predictions

另外,我建议将类的名称更改为Data更清晰、更易于理解的名称,甚至DataUnit在我看来会更好,因为Data它太笼统了。

于 2020-08-31T15:08:25.360 回答