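I am building a multilabel classifier that predicts labels from a text field, for example predicting genres from a movie title. I would like to use MultiLabelBinarizer() to binarize the column that contains all applicable genre labels. For example, ['action','comedy','drama'] is split into three columns with 0/1 values.
The reason I am using MultiLabelBinarizer() is that I can use its built-in inverse_transform() function to convert an output array (e.g. array([0, 0, 1, 0, 1])) directly into user-friendly text output (['action','drama']).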
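To make the round trip concrete, here is a minimal sketch of fit_transform() followed by inverse_transform(). The three genre lists are made-up toy data, not my real dataset. Note that inverse_transform() returns a list of tuples, one tuple per row.

```python
from sklearn.preprocessing import MultiLabelBinarizer

# Toy genre lists (hypothetical, for illustration only)
genres = [["action", "comedy", "drama"], ["drama"], ["action", "thriller"]]

mlb = MultiLabelBinarizer()
y = mlb.fit_transform(genres)  # one 0/1 column per distinct genre

# Columns follow the sorted order of the labels seen during fit
print(list(mlb.classes_))  # ['action', 'comedy', 'drama', 'thriller']
print(y.tolist())          # [[1, 1, 1, 0], [0, 0, 1, 0], [1, 0, 0, 1]]

# inverse_transform() maps each 0/1 row back to a tuple of label names
decoded = mlb.inverse_transform(y)
print(decoded)  # [('action', 'comedy', 'drama'), ('drama',), ('action', 'thriller')]
```

The decoding only works with the same fitted mlb object, because the column order lives in mlb.classes_.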
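The classifier works, but I am having trouble predicting on new data. I cannot find a way to integrate MultiLabelBinarizer() into my pipeline so that it can be saved and reloaded for inference on new data. One solution is to save it separately as a pickle object and load it back each time, but I would like to avoid that dependency in production.
I know this is similar to the tf-idf vectorizer that I build into the pipeline, but the difference is that it is applied to the target column (the genre labels) rather than to my independent variable (the text comments). Here is my code for training the multilabel SVM: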
import pickle  # stdlib, used only to write the binarizer to disk below

def svm_train(df):
    # Binarize the genre lists into one 0/1 column per genre
    mlb = MultiLabelBinarizer()
    y = mlb.fit_transform(df['Genres'])
    with mlflow.start_run():
        x_train, x_test, y_train, y_test = train_test_split(df['Movie Title'], y, test_size=0.3)
        # Instantiate TF-IDF Vectorizer and SVM Model
        tfidf_vect = TfidfVectorizer()
        mdl = OneVsRestClassifier(LinearSVC(loss='hinge'))
        svm_pipeline = Pipeline([('tfidf', tfidf_vect), ('clf', mdl)])
        svm_pipeline.fit(x_train, y_train)
        prediction = svm_pipeline.predict(x_test)
        report = classification_report(y_test, prediction, target_names=mlb.classes_)
        mlflow.sklearn.log_model(svm_pipeline, "Multilabel Classifier")
        # log_artifact() expects a local file path, not a Python object,
        # so pickle the binarizer first (the file name is arbitrary)
        with open("mlb.pkl", "wb") as f:
            pickle.dump(mlb, f)
        mlflow.log_artifact("mlb.pkl", "MLB")
    return report

svm_train(df)
Inference consists of reloading the saved model from MLflow in a separate Databricks notebook (the same as loading a pickle file) and predicting with the pipeline:
def predict_labels(new_data):
    model_uri = '...MLflow path...'
    model = mlflow.sklearn.load_model(model_uri)
    predictions = model.predict(new_data)
    # If I can't package the MultiLabelBinarizer() into the Pipeline, this
    # is where I'd have to load the pickle object mlb
    # so that I can inverse_transform()
    return mlb.inverse_transform(predictions)

new_data = ['Some movie title']
predict_labels(new_data)
['action','comedy']
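What I am after is something like the following: a single object that owns both the feature pipeline and the fitted binarizer, so that predict() already returns label names. This is only a rough sketch of the idea, not something I have running in production. The class name LabelDecodingClassifier and the toy titles and genres are hypothetical, and I have not verified how such a custom class behaves when it is logged and reloaded through MLflow.

```python
from sklearn.base import BaseEstimator
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.multiclass import OneVsRestClassifier
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.svm import LinearSVC


class LabelDecodingClassifier(BaseEstimator):
    """Hypothetical wrapper that keeps the target binarizer next to the pipeline."""

    def __init__(self, pipeline):
        self.pipeline = pipeline

    def fit(self, X, y):
        # y is an iterable of label lists, e.g. [['action'], ['action', 'comedy']]
        self.mlb_ = MultiLabelBinarizer()
        self.pipeline.fit(X, self.mlb_.fit_transform(y))
        return self

    def predict(self, X):
        # Decode the 0/1 indicator matrix back into tuples of label names
        return self.mlb_.inverse_transform(self.pipeline.predict(X))


# Toy data (made up for illustration)
titles = [
    "Fast chase explosion",
    "Funny wedding mishap",
    "Explosion at the funny wedding",
    "Quiet family drama",
]
genres = [["action"], ["comedy"], ["action", "comedy"], ["drama"]]

inner = Pipeline([
    ("tfidf", TfidfVectorizer()),
    ("clf", OneVsRestClassifier(LinearSVC(loss="hinge"))),
])
model = LabelDecodingClassifier(inner).fit(titles, genres)

preds = model.predict(["Explosion chase"])  # a list with one tuple of genre names
print(preds)
```

With this shape the inference side would only need model.predict(new_data), but it trades the separate pickle file for a custom class that has to be available wherever the model is loaded, which may be the same kind of dependency in a different place.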
These are all the libraries I am using:
import pandas as pd
import numpy as np
import mlflow
import mlflow.sklearn
import glob, os
from pyspark.sql import DataFrame
from sklearn.pipeline import Pipeline
from sklearn import preprocessing
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.multiclass import OneVsRestClassifier
from sklearn import svm
from sklearn.svm import LinearSVC
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from sklearn.metrics import accuracy_score, precision_score, recall_score