0

我使用以下代码:

fpr, tpr, t = roc_curve(true_categories, predicted_categories)

我知道predict_categories应该是概率或置信度。

但是由于 predict_proba() 似乎不适用于 keras 功能 API,我应该如何正确获取 predict_categories?

我试过了

predicted_categories = tf.argmax(y_pred, axis=1)

ROC 看起来像这样:

锐角ROC

和这个

predicted_categories =  = tf.reduce_max(y_pred, axis=1, keepdims=True)

ROC 看起来像这样:

AUC ~50% 的 ROC,应该是 85%

两个都不是我想象中的...

任何建议表示赞赏!

4

0 回答 0