2

我正在使用高斯朴素贝叶斯从 Pandas 数据框中训练模型,但在使用precision_recall_curve 时出现错误。文档说precision_recall_curve将预测概率作为输入(至少在我阅读时)所以我希望下面的工作(xtrain和xtest分别是具有736和184行的Pandas数据帧;ytrain / ytest是具有736和184的系列行分别):

nb = GaussianNB()
nb.fit(xtrain, ytrain)
predicted = nb.predict_proba(xtest)
precision, recall, threshold = precision_recall_curve(ytest, predicted)

我希望上述方法能够正常工作,但是我收到“IndexError:索引 230 超出大小 184 的范围”。如果我改为:

predicted = nb.predict(xtest)
precision, recall, threshold = precision_recall_curve(ytest, predicted)

然后它正确执行。184 是 xtest 和 ytest 中的行数,但 230 不是任何这些结构的维度。有人可以解释差异或我应该如何为此目的使用precision_recall_curve?

4

1 回答 1

1

如果这是二进制分类,请尝试使用以下内容,

predicted = nb.predict_proba(xtest)
precision, recall, threshold = precision_recall_curve(ytest, predicted[:,1])
于 2013-08-23T18:06:50.407 回答