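I have a DataFrame of time series in which the columns are the time values (in order) and each row is a separate series. I also have extra columns that give the category of each row, which in turn determines the line style and color.
Here is the DataFrame: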
>>> df
                 cat  (frac_norm, 2, 1)                                                                                             cluster
month_rel                   -5        -4        -3        -2        -1         0         1         2         3         4         5
user1   user2
3414845 4232621  -1b  0.760675  0.789854   0.95941  0.867755  0.790102         1  0.588729  0.719073  0.695572  0.647696  0.656323        4
4369232 3370279  -1b  0.580436  0.546761   0.71343  0.742033  0.802198  0.389957  0.861451  0.651786  0.798265  0.476305  0.896072        0
22771   3795428  -1b  0.946188  0.499531  0.834885  0.825772  0.754018   0.67823  0.430692  0.353989  0.333761  0.284759  0.260501        2
2660226 3126314  -1b  0.826701   0.81203  0.765182  0.680162  0.763475  0.802632         1  0.780186  0.844019  0.868698  0.722672        4
4154510 4348009  -1b         1  0.955656  0.677647  0.911556   0.76613  0.743759   0.61798  0.606536  0.715528  0.614902  0.482267        3
2860801 164553   -1b  0.870056  0.371981  0.640212  0.835185  0.673108  0.536585         1  0.850242  0.551198  0.873016  0.635556        4
120577  3480468  -1b    0.8197  0.879873  0.961178         1  0.855465  0.827824  0.827139  0.304011  0.574978  0.473996  0.358934        3
6692132 5095003  -1b         1  0.995859  0.738418  0.991217  0.854336  0.936518  0.910347  0.883205  0.987796  0.699433  0.815072        4
2515737 4263756  -1b  0.949047  0.990238  0.899524         1  0.961066   0.83703  0.835114  0.759142  0.749727  0.886913  0.936961        4
707596  2856619  -1b  0.780538  0.702179  0.568627         1  0.601382  0.789116         0 0.0714286         0  0.111969 0.0739796        2
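I can produce the plot below, where the x-axis is the ordered values of ('frac_norm', 2, 1), the color depends on the value of cluster, and the line style depends on the value of cat. However, it is built row by row. Is there a way to vectorize this using groupby?
The code I use to generate the image: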
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

colors = ['r', 'g', 'b', 'c', 'y', 'k']
lnst = ['-', '--']
cats = np.sort(df['cat'].unique())
clusters = np.sort(df['cluster'].unique())
# map each cluster value to a color and each cat value to a line style
colordict = dict(zip(clusters, colors))
lnstdict = dict(zip(cats, lnst))

fig, ax = plt.subplots()
# I first do it by `cluster` value
for clus_val in clusters:
    clr = colordict[clus_val]
    subset = df[df['cluster'] == clus_val]
    # and then plot each row individually, setting the color and linestyle
    for _, row in subset.iterrows():
        ax.plot(row[('frac_norm', 2, 1)], color=clr,
                linestyle=lnstdict[row['cat'].iloc[0]])
The code to generate df:
import pandas as pd
import numpy as np
vals = np.array([['-1b', 0.7606747496046389, 0.7898535589129476, 0.959409594095941,
0.8677546569280126, 0.7901020186672455, 1.0, 0.5887286145588728,
0.7190726452719073, 0.6955719557195572, 0.6476962793343348,
0.6563233814156323, 4],
['-1b', 0.5804363905325444, 0.5467611336032389,
0.7134300126103406, 0.7420329670329671, 0.8021978021978022,
0.389957264957265, 0.861451048951049, 0.6517857142857143,
0.798265460030166, 0.4763049450549451, 0.8960720130932898, 0],
['-1b', 0.9461875843454791, 0.49953095684803, 0.8348848603625673,
0.8257715338553662, 0.7540183696900115, 0.6782302664655606,
0.43069179143004643, 0.35398860398860393, 0.33376068376068374,
0.28475935828877, 0.260501012145749, 2],
['-1b', 0.8267008985879333, 0.8120300751879698,
0.7651821862348178, 0.680161943319838, 0.7634749524413443,
0.8026315789473684, 1.0, 0.7801857585139319, 0.8440191387559809,
0.8686980609418281, 0.7226720647773278, 4],
['-1b', 1.0, 0.955656108597285, 0.6776470588235294,
0.9115556882651537, 0.766129636568003, 0.7437589670014347,
0.6179800221975582, 0.6065359477124183, 0.715527950310559,
0.6149019607843138, 0.4822670674109059, 3],
['-1b', 0.8700564971751412, 0.3719806763285024,
0.6402116402116402, 0.8351851851851851, 0.6731078904991948,
0.5365853658536585, 1.0, 0.8502415458937197, 0.55119825708061,
0.873015873015873, 0.6355555555555555, 4],
['-1b', 0.8196997807387418, 0.879872907246731, 0.961178456344944,
1.0, 0.8554654738607772, 0.8278240873814314, 0.8271388025408839,
0.3040112596762843, 0.5749778172138421, 0.47399605003291634,
0.35893441346004046, 3],
['-1b', 1.0, 0.9958592132505176, 0.7384176764076977,
0.9912165129556433, 0.8543355440923606, 0.9365176566646254,
0.9103471520053926, 0.8832054560954816, 0.9877955758962623,
0.6994328922495274, 0.8150724637681159, 4],
['-1b', 0.9490474080638015, 0.9902376128200405,
0.8995240613432046, 1.0, 0.9610655737704917, 0.837029893924783,
0.8351136964569011, 0.759142496847415, 0.7497267759562841,
0.8869130313976105, 0.9369612979550449, 4],
['-1b', 0.7805383022774327, 0.7021791767554478,
0.5686274509803921, 1.0, 0.6013824884792627, 0.7891156462585033,
0.0, 0.07142857142857142, 0.0, 0.11196911196911197,
0.07397959183673469, 2]], dtype=object)
cols = pd.MultiIndex.from_tuples([( 'cat', ''),
(('frac_norm', 2, 1), -5),
(('frac_norm', 2, 1), -4),
(('frac_norm', 2, 1), -3),
(('frac_norm', 2, 1), -2),
(('frac_norm', 2, 1), -1),
(('frac_norm', 2, 1), 0),
(('frac_norm', 2, 1), 1),
(('frac_norm', 2, 1), 2),
(('frac_norm', 2, 1), 3),
(('frac_norm', 2, 1), 4),
(('frac_norm', 2, 1), 5),
( 'cluster', '')],
names=[None, 'month_rel'])
idx = pd.MultiIndex.from_tuples([(3414845, 4232621),
(4369232, 3370279),
( 22771, 3795428),
(2660226, 3126314),
(4154510, 4348009),
(2860801, 164553),
( 120577, 3480468),
(6692132, 5095003),
(2515737, 4263756),
( 707596, 2856619)],
names=['user1', 'user2'])
df = pd.DataFrame(vals, columns=cols, index=idx)
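For reference, here is a rough sketch of the kind of groupby version I have in mind. It is only an illustration under stated assumptions: `toy` is a hypothetical stand-in frame with flat columns (not the MultiIndex columns of df above), the values are random, and the second category '+1b' is invented so that both line styles are used. The idea is to group by (cluster, cat) and hand each group's rows to a single ax.plot call as a 2-D array.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs anywhere; drop for interactive use
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Hypothetical stand-in data (assumption): flat columns, random values.
months = [-2, -1, 0, 1, 2]
rng = np.random.default_rng(0)
toy = pd.DataFrame(rng.random((6, len(months))), columns=months)
toy["cat"] = ["-1b", "-1b", "-1b", "+1b", "+1b", "+1b"]  # '+1b' is invented
toy["cluster"] = [0, 0, 1, 1, 2, 2]

colors = ['r', 'g', 'b', 'c', 'y', 'k']
lnst = ['-', '--']
colordict = dict(zip(np.sort(toy["cluster"].unique()), colors))
lnstdict = dict(zip(np.sort(toy["cat"].unique()), lnst))

fig, ax = plt.subplots()
# One ax.plot call per (cluster, cat) group instead of one per row.
for (clus_val, cat_val), group in toy.groupby(["cluster", "cat"]):
    # Shape (n_months, n_rows): matplotlib draws one line per column.
    y = group[months].to_numpy().T
    ax.plot(months, y, color=colordict[clus_val], linestyle=lnstdict[cat_val])

n_lines = len(ax.lines)
```

This still loops, but over groups rather than rows, so the number of ax.plot calls is bounded by the number of (cluster, cat) combinations. Whether it carries over directly to the MultiIndex columns of df is something I have not verified.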


