33

我有一个包含值的数组,我想创建它的直方图。我主要对低端号码感兴趣,想把每一个300以上的号码都收集到一个箱子里。此 bin 应与所有其他(同样宽的) bin 具有相同的宽度。我怎样才能做到这一点?

注意:这个问题与这个问题有关:Defining bin width/x-axis scale in Matplotlib histogram

这是我到目前为止所尝试的:

import matplotlib.pyplot as plt
import numpy as np

def plot_histogram_01():
    np.random.seed(1)
    values_A = np.random.choice(np.arange(600), size=200, replace=True).tolist()
    values_B = np.random.choice(np.arange(600), size=200, replace=True).tolist()

    bins = [0, 25, 50, 75, 100, 125, 150, 175, 200, 225, 250, 275, 300, 600]

    fig, ax = plt.subplots(figsize=(9, 5))
    _, bins, patches = plt.hist([values_A, values_B], normed=1,  # normed is deprecated and will be replaced by density
                                bins=bins,
                                color=['#3782CC', '#AFD5FA'],
                                label=['A', 'B'])

    xlabels = np.array(bins[1:], dtype='|S4')
    xlabels[-1] = '300+'

    N_labels = len(xlabels)
    plt.xlim([0, 600])
    plt.xticks(25 * np.arange(N_labels) + 12.5)
    ax.set_xticklabels(xlabels)

    plt.yticks([])
    plt.title('')
    plt.setp(patches, linewidth=0)
    plt.legend()

    fig.tight_layout()
    plt.savefig('my_plot_01.png')
    plt.close()

这是结果,看起来不太好: 在此处输入图像描述

然后我在其中更改了 xlim 行:

plt.xlim([0, 325])

结果如下: 在此处输入图像描述

它看起来或多或少像我想要的那样,但最后一个垃圾箱现在不可见。我缺少哪个技巧来可视化宽度为 25 的最后一个 bin?

4

2 回答 2

46

Numpy 有一个方便的函数来处理这个问题:np.clip. 尽管名称听起来像什么,但它不会删除值,它只是将它们限制在您指定的范围内。基本上,它内嵌了 Artem 的“肮脏黑客”。您可以将值保留原样,但在hist调用中,只需将数组包装在np.clip调用中,就像这样

plt.hist(np.clip(values_A, bins[0], bins[-1]), bins=bins)

这更好,原因有很多:

  1. 它的速度要快得多——至少对于大量元素来说是这样。Numpy 在 C 级别完成其工作。对 python 列表进行操作(如在 Artem 的列表理解中)对每个元素都有很多开销。基本上,如果你可以选择使用 numpy,你应该这样做。

  2. 您可以在需要的地方正确执行,从而减少代码出错的机会。

  3. 您不需要保留数组的第二个副本,这可以减少内存使用(除了这一行)并进一步减少出错的机会。

  4. 使用bins[0], bins[-1]而不是硬编码这些值可以减少再次出错的机会,因为您可以在bins定义的地方更改 bin;您无需记住在呼叫clip或其他任何地方更改它们。

因此,将它们放在一起,就像在 OP 中一样:

import matplotlib.pyplot as plt
import numpy as np

def plot_histogram_01():
    np.random.seed(1)
    values_A = np.random.choice(np.arange(600), size=200, replace=True)
    values_B = np.random.choice(np.arange(600), size=200, replace=True)

    bins = np.arange(0,350,25)

    fig, ax = plt.subplots(figsize=(9, 5))
    _, bins, patches = plt.hist([np.clip(values_A, bins[0], bins[-1]),
                                 np.clip(values_B, bins[0], bins[-1])],
                                # normed=1,  # normed is deprecated; replace with density
                                density=True,
                                bins=bins, color=['#3782CC', '#AFD5FA'], label=['A', 'B'])

    xlabels = bins[1:].astype(str)
    xlabels[-1] += '+'

    N_labels = len(xlabels)
    plt.xlim([0, 325])
    plt.xticks(25 * np.arange(N_labels) + 12.5)
    ax.set_xticklabels(xlabels)

    plt.yticks([])
    plt.title('')
    plt.setp(patches, linewidth=0)
    plt.legend(loc='upper left')

    fig.tight_layout()
plot_histogram_01()

上面代码的结果

于 2015-05-18T14:04:21.337 回答
5

对不起,我不熟悉 matplotlib。所以我有一个肮脏的黑客给你。我只是将所有大于 300 的值放在一个 bin 中并更改了 bin 大小。

问题的根源在于 matplotlib 试图将所有的 bin 放在图上。在 RI 中会将我的 bin 转换为因子变量,因此它们不会被视为实数。

import matplotlib.pyplot as plt
import numpy as np

def plot_histogram_01():
    np.random.seed(1)
    values_A = np.random.choice(np.arange(600), size=200, replace=True).tolist()
    values_B = np.random.choice(np.arange(600), size=200, replace=True).tolist()
    values_A_to_plot = [301 if i > 300 else i for i in values_A]
    values_B_to_plot = [301 if i > 300 else i for i in values_B]

    bins = [0, 25, 50, 75, 100, 125, 150, 175, 200, 225, 250, 275, 300, 325]

    fig, ax = plt.subplots(figsize=(9, 5))
    _, bins, patches = plt.hist([values_A_to_plot, values_B_to_plot], normed=1,  # normed is deprecated and will be replaced by density
                                bins=bins,
                                color=['#3782CC', '#AFD5FA'],
                                label=['A', 'B'])

    xlabels = np.array(bins[1:], dtype='|S4')
    xlabels[-1] = '300+'

    N_labels = len(xlabels)

    plt.xticks(25 * np.arange(N_labels) + 12.5)
    ax.set_xticklabels(xlabels)

    plt.yticks([])
    plt.title('')
    plt.setp(patches, linewidth=0)
    plt.legend()

    fig.tight_layout()
    plt.savefig('my_plot_01.png')
    plt.close()

plot_histogram_01()

在此处输入图像描述

于 2014-10-06T15:02:59.007 回答