在 numpy 中,您可以索引大小N最大为 index N-1(沿给定轴)的数组,否则您将得到您所看到的 IndexError。为了检查索引可以达到多高,您可以打印target_q.shape. 在您的情况下,它会告诉您(10, 1),这意味着如果您索引target_q[i, j],则i最多可以为 9,j最多可以为 0。您在行中所做的是在第二个位置 ( )target_q[batch_index, actions]上插入所谓的花式索引操作,并且已满的。因此,您尝试多次使用 1 进行索引,而允许的最高索引值为 0。可行的方法是:jactions
import numpy as np
batch_size = 10
target_q = np.ones((10, 1))
# changed to zeros below
actions = np.zeros((10, ), dtype=int)
batch_index = np.arange(batch_size, dtype=np.int32)
print(actions)
print(target_q.shape)
print(target_q[batch_index, 0])
print(target_q[batch_index, actions])
打印:
[0 0 0 0 0 0 0 0 0 0]
(10, 1)
[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]