4

背景

在混淆矩阵中,对角线表示预测标签与正确标签匹配的情况。所以对角线是好的,而所有其他的单元格都是坏的。为了澄清非专家在 CM 中什么是好的和什么是坏的,我想给对角线一个与其他颜色不同的颜色。我想用Python & Seaborn来实现这一点。

基本上我试图实现这个问题在 R 中的作用(ggplot2 Heatmap 2 Different Color Schemes - Confusion Matrix: Matches in different color Scheme than Missclassifications

带有热图的正常 Seaborn 混淆矩阵

import numpy as np
import seaborn as sns

cf_matrix = np.array([[50, 2, 38],
                      [7, 43, 32],
                      [9,  4, 76]])

sns.heatmap(cf_matrix, annot=True, cmap='Blues')  # cmap='OrRd'

这导致了这个图像:

Seaborn 混淆矩阵与颜色图“蓝调”

目标

我想用例如为非对角线单元格着色cmap='OrRd'。所以我想会有 2 个颜色条,1 个蓝色用于对角线,1 个用于其他单元格。最好两个颜色条的值都匹配(例如,0-70 和 0-70 和 0-40 不匹配)。我将如何处理这个?

以下不是用代码制作的,而是用照片编辑软件制作的:

所需的混淆矩阵配色方案

4

2 回答 2

10

您可以mask=在调用中使用heatmap()来选择要显示的单元格。对对角线和 off_diagonal 单元格使用两个不同的掩码,您可以获得所需的输出:

import numpy as np
import seaborn as sns

cf_matrix = np.array([[50, 2, 38],
                      [7, 43, 32],
                      [9,  4, 76]])

vmin = np.min(cf_matrix)
vmax = np.max(cf_matrix)
off_diag_mask = np.eye(*cf_matrix.shape, dtype=bool)

fig = plt.figure()
sns.heatmap(cf_matrix, annot=True, mask=~off_diag_mask, cmap='Blues', vmin=vmin, vmax=vmax)
sns.heatmap(cf_matrix, annot=True, mask=off_diag_mask, cmap='OrRd', vmin=vmin, vmax=vmax, cbar_kws=dict(ticks=[]))

在此处输入图像描述

如果您想花哨,可以使用 GridSpec 创建轴以获得更好的布局:

将 numpy 导入为 np 将 seaborn 导入为 sns

fig = plt.figure()
gs0 = matplotlib.gridspec.GridSpec(1,2, width_ratios=[20,2], hspace=0.05)
gs00 = matplotlib.gridspec.GridSpecFromSubplotSpec(1,2, subplot_spec=gs0[1], hspace=0)

ax = fig.add_subplot(gs0[0])
cax1 = fig.add_subplot(gs00[0])
cax2 = fig.add_subplot(gs00[1])

sns.heatmap(cf_matrix, annot=True, mask=~off_diag_mask, cmap='Blues', vmin=vmin, vmax=vmax, ax=ax, cbar_ax=cax2)
sns.heatmap(cf_matrix, annot=True, mask=off_diag_mask, cmap='OrRd', vmin=vmin, vmax=vmax, ax=ax, cbar_ax=cax1, cbar_kws=dict(ticks=[]))

在此处输入图像描述

于 2020-11-12T09:46:19.913 回答
1

您可以先用颜色图“OrRd”绘制热图,然后用颜色图“蓝色”的热图覆盖它,将上下三角形值替换为 NaN,请参见以下示例:

def diagonal_heatmap(m):

    vmin = np.min(m)
    vmax = np.max(m)    
    
    sns.heatmap(cf_matrix, annot=True, cmap='OrRd', vmin=vmin, vmax=vmax)

    diag_nan = np.full_like(m, np.nan, dtype=float)
    np.fill_diagonal(diag_nan, np.diag(m))
    
    sns.heatmap(diag_nan, annot=True, cmap='Blues', vmin=vmin, vmax=vmax, cbar_kws={'ticks':[]}) 




cf_matrix = np.array([[50, 2, 38],
                      [7, 43, 32],
                      [9,  4, 76]])

diagonal_heatmap(cf_matrix)
于 2020-11-12T09:46:30.673 回答