我正在运行此 Github 上提供的代码 - https://github.com/arthurdouillard/CVPR2021_PLOP/blob/381cb795d70ba8431d864e4b60bb84784bc85ec9/metrics/stream_metrics.py
现在我可以查看颜色变化的混淆矩阵,但我看不到实际数字。您建议进行哪些更改以在可视化上获取这些数字?
我正在运行此 Github 上提供的代码 - https://github.com/arthurdouillard/CVPR2021_PLOP/blob/381cb795d70ba8431d864e4b60bb84784bc85ec9/metrics/stream_metrics.py
现在我可以查看颜色变化的混淆矩阵,但我看不到实际数字。您建议进行哪些更改以在可视化上获取这些数字?
您可以更改confusion_matrix_to_fig
为:
import itertools
def confusion_matrix_to_fig(self):
cm = self.confusion_matrix.astype('float') / (self.confusion_matrix.sum(axis=1) +
0.000001)[:, np.newaxis]
fig, ax = plt.subplots()
im = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
ax.figure.colorbar(im, ax=ax)
ax.set(title=f'Confusion Matrix', ylabel='True label', xlabel='Predicted label')
# Adding text to cm
thresh = cm.max() / 1.5 # Thresh is used to decide color of text (white or black)
for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
ax.text(j, i, "{:0.4f}".format(cm[i, j]),
horizontalalignment="center",
color="white" if cm[i, j] > thresh else "black")
fig.tight_layout()
return fig
您可能已经注意到,这将绘制混淆矩阵的归一化值。
查看此笔记本以获取更多信息。