1

我使用 Matplolib 和 Pandas 数据框创建了一个散点图,现在我想为其添加一个图例。这是我的代码:

colors = ['red' if x >= 150 and x < 200 else 
          'green' if x >= 200 and x < 400 else
          'purple' if x >= 400 and x < 600 else
          'yellow' if x >= 600 else 'teal' for x in myData.R]


ax1.scatter(myData.X, myData.Y, s=20, c=colors, marker='_', label='Test')
ax1.legend(loc='upper left', frameon=False)

这里发生的是,根据 的值myData.R,散点图中点的颜色会发生变化。所以,由于颜色是“动态的”,我在创建图例时遇到了很多麻烦。实际代码只会创建一个带有一个名为“测试”的标签的图例,附近没有任何颜色。

以下是数据示例:

       X  Y    R
0      1  945  1236.334519
0      1  950   212.809352
0      1  950   290.663847
0      1  961   158.156856

我试过这个,但我不明白的是:

  1. 如何动态为图例设置标签?例如,我的代码说'red' if x >= 150,所以在图例上应该有一个红色方块,其附近> 150。但由于我没有手动添加任何标签,我很难理解这一点。

  2. 在尝试了以下之后,我只得到了一个带有单个标签“类”的图例:

`legend1 = ax1.legend(*scatter.legend_elements(), loc="左下", title="类")

ax1.add_artist(legend1)`

任何形式的建议表示赞赏!

4

1 回答 1

2

可以加速的部分代码是使用纯 Python 循环创建字符串列表。Pandas 使用 numpy 的过滤非常有效。绘制散点图主要取决于点的数量,当所有点一次绘制或分五部分绘制时,点数不会改变。

在循环中使用 matplotlib 的 scatter 的一些示例代码:

from matplotlib import pyplot as plt
import numpy as np
import pandas as pd

N = 500
myData = pd.DataFrame({'X': np.round(np.random.uniform(-1000, 1000, N), -2), 
                       'Y': np.random.uniform(-800, 800, N)})
myData['R'] = np.sqrt(myData.X ** 2 + myData.Y ** 2)

fig, ax1 = plt.subplots()

bounds = [150, 200, 400, 600]
colors = ['teal', 'red', 'green', 'purple', 'gold']
for b0, b1, col in zip([None]+bounds, bounds+[None], colors):
    if b0 is None:
        filter = (myData.R < b1)
        label = f'$ R < {b1} $'
    elif b1 is None:
        filter = (myData.R >= b0)
        label = f'${b0} ≤ R $'
    else:
        filter = (myData.R >= b0) & (myData.R < b1)
        label = f'${b0} ≤ R < {b1}$'
    ax1.scatter(myData.X[filter], myData.Y[filter], s=20, c=col, marker='_', label=label)
ax1.legend()
plt.show()

matplotlib 的散点图

或者,cut可以使用 pandas 来创建类别,而 seaborn 的功能(例如其hue参数)可以进行着色并自动创建图例。

from matplotlib import pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

N = 500
myData = pd.DataFrame({'X': np.round( np.random.uniform(-1000, 1000, N),-2), 'Y': np.random.uniform(-800, 800, N)})
myData['R'] = np.sqrt(myData.X ** 2 + myData.Y ** 2)

fig, ax1 = plt.subplots()

bounds = [150, 200, 400, 600]
colors = ['teal', 'red', 'green', 'purple', 'gold']

hues = pd.cut(myData.R, [0]+bounds+[2000], right=False)
sns.scatterplot(myData.X, myData.Y, hue=hues, hue_order=hues.cat.categories, palette=colors, s=20, marker='_', ax=ax1)
plt.show()

seaborn 散点图

于 2020-06-24T10:58:43.430 回答