我需要在 Pylab 中创建一个图表,在其中我将沿 Y 轴绘制彩色点。y 轴从 0 到 100。我还有一个包含 100 个元素的列表,这些元素是 +1 或 -1。该列表必须与图表的 Y 轴相对应。
例如,如果列表中的第五个元素是 +1,我需要在 Y 轴上的 y=5 上绘制一个绿点。如果列表中的第五个元素是 -1,则该点必须是红色的。
我必须对列表中的所有元素执行此操作。
我在 Pylab 中绘制了简单的图表,但在这种情况下我完全迷失了。任何帮助将不胜感激。谢谢!!
我需要在 Pylab 中创建一个图表,在其中我将沿 Y 轴绘制彩色点。y 轴从 0 到 100。我还有一个包含 100 个元素的列表,这些元素是 +1 或 -1。该列表必须与图表的 Y 轴相对应。
例如,如果列表中的第五个元素是 +1,我需要在 Y 轴上的 y=5 上绘制一个绿点。如果列表中的第五个元素是 -1,则该点必须是红色的。
我必须对列表中的所有元素执行此操作。
我在 Pylab 中绘制了简单的图表,但在这种情况下我完全迷失了。任何帮助将不胜感激。谢谢!!
import matplotlib.pyplot as plt
import numpy as np
data = np.array([1,1,-1,-1,1])
cmap = np.array([(1,0,0), (0,1,0)])
uniqdata, idx = np.unique(data, return_inverse=True)
N = len(data)
fig, ax = plt.subplots()
plt.scatter(np.zeros(N), np.arange(1, N+1), s=100, c=cmap[idx])
plt.grid()
plt.show()
产量
解释:
如果你打印出来np.unique(data, return_inverse=True)
,你会看到它返回一个数组元组:
In [71]: np.unique(data, return_inverse=True)
Out[71]: (array([-1, 1]), array([1, 1, 0, 0, 1]))
第一个数组表示其中的唯一值data
是 -1 和 1。第二个数组在data
-1 和 1 的任何地方分配值 0 和data
1。本质上,np.unique
允许我们转换[1,1,-1,-1,1]
为[1, 1, 0, 0, 1]
。现在cmap[idx]
是一个 RGB 值数组:
In [74]: cmap[idx]
Out[74]:
array([[0, 1, 0],
[0, 1, 0],
[1, 0, 0],
[1, 0, 0],
[0, 1, 0]])
这是 NumPy 数组上所谓的“花式索引”的应用。cmap[0]
是 的第一行cmap
。cmap[1]
是第二行cmap
。cmap[idx]
是一个数组,其中的第 i 个元素cmap[idx]
是cmap[idx[i]]
. 所以,你最终cmap[idx]
会成为一个二维数组,其中第 i 行是cmap[idx[i]]
。因此cmap[idx]
可以被认为是一系列 RGB 颜色值。
如果您有多个点并且希望将它们绘制成列,我能想到的最简单的方法是ax.scatter
为每个列表调用一次data
:
import matplotlib.pyplot as plt
import numpy as np
def plot_data(ax, data, xval):
N = len(data)
uniqdata, idx = np.unique(data, return_inverse=True)
ax.scatter(np.ones(N)*xval, np.arange(1, N+1), s=100, c=cmap[idx])
cmap = np.array([(1,0,0), (0,1,0)])
fig, ax = plt.subplots()
data = np.array([1,1,-1,-1,1])
data2 = np.array([1,-1,1,1,-1])
plot_data(ax, data, 0)
plot_data(ax, data2, 1)
plt.grid()
plt.show()
这样做的好处是它相对容易理解。这样做的坏处是它ax.scatter
不止一次调用。如果您有大量数据集,则整理数据并调用ax.scatter
一次会更有效。这对于 Matplotlib 来说更快,但它的代码有点复杂:
import matplotlib.pyplot as plt
import numpy as np
import itertools as IT
def plot_dots(ax, datasets):
N = sum(len(data) for data in datasets)
x = np.fromiter(
(i for i, data in enumerate(datasets) for j in np.arange(len(data))),
dtype='float', count=N)
y = np.fromiter(
(j for data in datasets for j in np.arange(1, len(data)+1)),
dtype='float', count=N)
c = np.fromiter(
(val for data in datasets
for rgb in cmap[np.unique(data, return_inverse=True)[-1]]
for val in rgb),
dtype='float', count=3*N).reshape(-1,3)
ax.scatter(x, y, s=100, c=c)
cmap = np.array([(1,0,0), (0,1,0)])
fig, ax = plt.subplots()
N = 100
datasets = [np.random.randint(2, size=5) for i in range(N)]
plot_dots(ax, datasets)
plt.grid()
plt.show()
参考: