1

我试图在 python 中的 knn 分类器上创建一个混淆矩阵,但标记的类是错误的。

数据集的 classes 属性是 2(良性)和 4(恶性),但是当我绘制混淆矩阵时,所有标签都是 2。我使用的代码是:

数据来源: http: //archive.ics.uci.edu/ml/datasets/Breast+Cancer+Wisconsin+%28Diagnostic%29

来自 UCI 的威斯康星乳腺癌(诊断)数据集上的 KNN 分类器:

data = pd.read_csv('/breast-cancer-wisconsin.data')
data.replace('?', 0, inplace=True)
data.drop('id', 1, inplace = True)


X = np.array(data.drop(' class ', 1))
Y = np.array(data[' class '])

X_train, X_test, Y_train, Y_test = train_test_split(X,Y,test_size=0.2)
clf = neighbors.KNeighborsClassifier()
clf.fit(X_train, Y_train)

accuracy = clf.score(X_test, Y_test)

绘制混淆矩阵

from sklearn.metrics import plot_confusion_matrix

disp = plot_confusion_matrix(clf, X_test, Y_test,
                               display_labels=Y,
                               cmap=plt.cm.Blues,)

混淆矩阵

4

1 回答 1

1

问题是您使用 指定display_labels参数Y,它应该只是用于绘图的目标名称。现在它只是使用出现在 中的前两个值,Y恰好是2, 2。还要注意,如docs中所述,显示的标签将与提供时指定的标签相同labels,因此您只需要:

from sklearn.metrics import plot_confusion_matrix
fig, ax = plt.subplots(figsize=(8,8))
disp = plot_confusion_matrix(clf, X_test, Y_test,
                               labels=np.unique(y),
                               cmap=plt.cm.Blues,ax=ax)

在此处输入图像描述

于 2020-05-14T08:11:15.173 回答