Suppose I have the following DataFrame Q_df:

        (0, 0)  (0, 1)  (0, 2)  (1, 0)  (1, 1)  (1, 2)  (2, 0)  (2, 1)  (2, 2)
(0, 0)   0.000    0.00     0.0    0.64   0.000     0.0   0.512   0.000     0.0
(0, 1)   0.000    0.00     0.8    0.00   0.512     0.0   0.000   0.512     0.0
(0, 2)   0.000    0.64     0.0    0.00   0.000     0.8   0.000   0.000     1.0
(1, 0)   0.512    0.00     0.0    0.00   0.000     0.8   0.512   0.000     0.0
(1, 1)   0.000    0.64     0.0    0.00   0.000     0.0   0.000   0.512     0.0
(1, 2)   0.000    0.00     0.8    0.64   0.000     0.0   0.000   0.000     1.0
(2, 0)   0.512    0.00     0.0    0.64   0.000     0.0   0.000   0.512     0.0
(2, 1)   0.000    0.64     0.0    0.00   0.512     0.0   0.512   0.000     0.0
(2, 2)   0.000    0.00     0.8    0.00   0.000     0.8   0.000   0.000     0.0

It was generated using the following code:

import itertools  # needed for itertools.product below

import numpy as np
import pandas as pd

states = list(itertools.product(range(3), repeat=2))

Q = np.array([[0.000,0.000,0.000,0.640,0.000,0.000,0.512,0.000,0.000],
[0.000,0.000,0.800,0.000,0.512,0.000,0.000,0.512,0.000],
[0.000,0.640,0.000,0.000,0.000,0.800,0.000,0.000,1.000],
[0.512,0.000,0.000,0.000,0.000,0.800,0.512,0.000,0.000],
[0.000,0.640,0.000,0.000,0.000,0.000,0.000,0.512,0.000],
[0.000,0.000,0.800,0.640,0.000,0.000,0.000,0.000,1.000],
[0.512,0.000,0.000,0.640,0.000,0.000,0.000,0.512,0.000],
[0.000,0.640,0.000,0.000,0.512,0.000,0.512,0.000,0.000],
[0.000,0.000,0.800,0.000,0.000,0.800,0.000,0.000,0.000]])

Q_df = pd.DataFrame(index=states, columns=states, data=Q)

For each row of Q, I would like to get the name of the column that holds that row's maximum value. If I try

policy = Q_df.idxmax()

then the resulting Series looks like this:

(0, 0)    (1, 0)
(0, 1)    (0, 2)
(0, 2)    (0, 1)
(1, 0)    (0, 0)
(1, 1)    (0, 1)
(1, 2)    (0, 2)
(2, 0)    (0, 0)
(2, 1)    (0, 1)
(2, 2)    (0, 2)

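The first row looks fine: its maximum element is 0.64, which occurs in column (1, 0). The same goes for the second row. For the third row, however, the maximum element is 1.0, which occurs in column (2, 2), so I would expect the corresponding value in policy to be (2, 2), not (0, 1).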

Any idea what is going wrong here?

1 Answer

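IIUC, you can pass axis=1 to idxmax. By default idxmax uses axis=0, which returns, for each column, the row label of that column's maximum. With axis=1 it returns, for each row, the column label of that row's maximum: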

policy = Q_df.idxmax(axis=1)

(0, 0)    (1, 0)
(0, 1)    (0, 2)
(0, 2)    (2, 2)
(1, 0)    (1, 2)
(1, 1)    (0, 1)
(1, 2)    (2, 2)
(2, 0)    (1, 0)
(2, 1)    (0, 1)
(2, 2)    (0, 2)
dtype: object
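
To make the difference between the two axes easier to see, here is a minimal sketch on a small, made-up frame (not the Q_df from the question) whose row and column labels differ. In the question both axes carry the same labels, which is probably why the axis=0 result looked plausible at first glance. The names df, per_column and per_row are illustrative only.

```python
import pandas as pd

# Hypothetical toy data, chosen so that row and column labels are distinct.
df = pd.DataFrame(
    [[0.0, 0.8, 0.8],
     [0.5, 0.0, 1.0]],
    index=["r0", "r1"],
    columns=["a", "b", "c"],
)

# axis=0 (the default): for each column, the row label of its maximum.
per_column = df.idxmax()

# axis=1: for each row, the column label of its maximum.
# On ties (row "r0" has 0.8 twice) the first matching label is returned.
per_row = df.idxmax(axis=1)

print(per_column)
print(per_row)
```

Here per_column is indexed by the column labels a, b, c, while per_row is indexed by the row labels r0, r1, so the two results can no longer be confused.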
answered 2016-07-30T07:29:07.167