0

我需要在 Pylab 中创建一个图表,在其中我将沿 Y 轴绘制彩色点。y 轴从 0 到 100。我还有一个包含 100 个元素的列表,这些元素是 +1 或 -1。该列表必须与图表的 Y 轴相对应。

例如,如果列表中的第五个元素是 +1,我需要在 Y 轴上的 y=5 上绘制一个绿点。如果列表中的第五个元素是 -1,则该点必须是红色的。

我必须对列表中的所有元素执行此操作。

我在 Pylab 中绘制了简单的图表,但在这种情况下我完全迷失了。任何帮助将不胜感激。谢谢!!

4

1 回答 1

2
import matplotlib.pyplot as plt
import numpy as np

data = np.array([1,1,-1,-1,1])
cmap = np.array([(1,0,0), (0,1,0)])
uniqdata, idx = np.unique(data, return_inverse=True)

N = len(data)
fig, ax = plt.subplots()
plt.scatter(np.zeros(N), np.arange(1, N+1), s=100, c=cmap[idx])
plt.grid()
plt.show()

产量

在此处输入图像描述


解释:

如果你打印出来np.unique(data, return_inverse=True),你会看到它返回一个数组元组:

In [71]: np.unique(data, return_inverse=True)
Out[71]: (array([-1,  1]), array([1, 1, 0, 0, 1]))

第一个数组表示其中的唯一值data是 -1 和 1。第二个数组在data-1 和 1 的任何地方分配值 0 和data1。本质上,np.unique允许我们转换[1,1,-1,-1,1][1, 1, 0, 0, 1]。现在cmap[idx]是一个 RGB 值数组:

In [74]: cmap[idx]
Out[74]: 
array([[0, 1, 0],
       [0, 1, 0],
       [1, 0, 0],
       [1, 0, 0],
       [0, 1, 0]])

这是 NumPy 数组上所谓的“花式索引”的应用。cmap[0]是 的第一行cmapcmap[1]是第二行cmapcmap[idx]是一个数组,其中的第 i 个元素cmap[idx]cmap[idx[i]]. 所以,你最终cmap[idx]会成为一个二维数组,其中第 i 行是cmap[idx[i]]。因此cmap[idx]可以被认为是一系列 RGB 颜色值。


如果您有多个点并且希望将它们绘制成列,我能想到的最简单的方法是ax.scatter为每个列表调用一次data

import matplotlib.pyplot as plt
import numpy as np

def plot_data(ax, data, xval):
    N = len(data)
    uniqdata, idx = np.unique(data, return_inverse=True)
    ax.scatter(np.ones(N)*xval, np.arange(1, N+1), s=100, c=cmap[idx])

cmap = np.array([(1,0,0), (0,1,0)])
fig, ax = plt.subplots()

data = np.array([1,1,-1,-1,1])
data2 = np.array([1,-1,1,1,-1])

plot_data(ax, data, 0)
plot_data(ax, data2, 1)

plt.grid()
plt.show()

在此处输入图像描述

这样做的好处是它相对容易理解。这样做的坏处是它ax.scatter不止一次调用。如果您有大量数据集,则整理数据并调用ax.scatter一次会更有效。这对于 Matplotlib 来说更快,但它的代码有点复杂:

import matplotlib.pyplot as plt
import numpy as np
import itertools as IT

def plot_dots(ax, datasets):
    N = sum(len(data) for data in datasets)
    x = np.fromiter(
        (i for i, data in enumerate(datasets) for j in np.arange(len(data))),
        dtype='float', count=N)
    y = np.fromiter(
        (j for data in datasets for j in np.arange(1, len(data)+1)),
        dtype='float', count=N)
    c = np.fromiter(
        (val for data in datasets
         for rgb in cmap[np.unique(data, return_inverse=True)[-1]]
         for val in rgb),
        dtype='float', count=3*N).reshape(-1,3)
    ax.scatter(x, y, s=100, c=c)

cmap = np.array([(1,0,0), (0,1,0)])
fig, ax = plt.subplots()

N = 100
datasets = [np.random.randint(2, size=5) for i in range(N)]

plot_dots(ax, datasets)

plt.grid()
plt.show()

在此处输入图像描述


参考:

于 2013-06-23T17:38:17.210 回答