当我问这个问题时,我没有足够的 python(或 scikit-learn)知识来回答。分类报告(如 prashant0598 所建议的)接近我需要的,尽管它实际上并不具有准确性。以下是分类报告的使用方法:
from sklearn.metrics import classification_report
import pandas as pd
y_pred = model.predict(val_ds)
y_pred = np.argmax(y_pred, axis=1)
y_true = np.concatenate([y for x, y in val_ds], axis=0)
cr = classification_report(y_true, y_pred, output_dict=True, target_names=class_names)
pd.DataFrame.from_dict(cr)
分类报告输出(除其他外)精度和召回率,这会有所帮助。
为了得到类的准确性,我们必须更多地手动执行此操作。这是一种方法:
from sklearn.metrics import accuracy_score
def class_accuracy(class_no):
pred_filter = y_true==class_no
acc = accuracy_score(y_true[pred_filter], y_pred[pred_filter])
return acc
{class_name: class_accuracy(i) for i, class_name in enumerate(class_names)}
{“雏菊”:0.6589147286821705,
“蒲公英”:0.75,
“玫瑰”:0.6,
“向日葵”:0.868421052631579,
“郁金香”:0.6942675159235668}
所以现在我知道了,向日葵是最容易预测的,而玫瑰则特别棘手!