-1

以下代码只是我原始代码的一个示例,如下所示

batch_size = 10
target_q = np.ones((10, 1))
actions = np.ones((10, ), dtype=int)
batch_index = np.arange(batch_size, dtype=np.int32)
print(target_q[batch_index, actions])
print(target_q.shape)

我收到以下错误 IndexError: index 1 is out of bounds for axis 1 with size 1.

有人可以解释这意味着什么以及如何纠正它。

提前致谢。

4

1 回答 1

0

在 numpy 中,您可以索引大小N最大为 index N-1(沿给定轴)的数组,否则您将得到您所看到的 IndexError。为了检查索引可以达到多高,您可以打印target_q.shape. 在您的情况下,它会告诉您(10, 1),这意味着如果您索引target_q[i, j],则i最多可以为 9,j最多可以为 0。您在行中所做的是在第二个位置 ( )target_q[batch_index, actions]上插入所谓的花式索引操作,并且已满的。因此,您尝试多次使用 1 进行索引,而允许的最高索引值为 0。可行的方法是:jactions

import numpy as np

batch_size = 10
target_q = np.ones((10, 1)) 
# changed to zeros below
actions = np.zeros((10, ), dtype=int)
batch_index = np.arange(batch_size, dtype=np.int32)
print(actions)
print(target_q.shape)
print(target_q[batch_index, 0]) 
print(target_q[batch_index, actions])

打印:

[0 0 0 0 0 0 0 0 0 0]
(10, 1)
[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
于 2020-09-03T09:40:29.397 回答