在给定多个索引数组的情况下,如何从 NumPy 数组中获取元素并进行广播?或者:我怎样才能简化/矢量化这个循环:
elems = np.random.rand(3, 10, 7) # shape N x I x M
ind = np.array([[1, 2], [3, 4], [0, 9]]) # shape N x J
res = np.stack([elems[i, ind[i]] for i in range(len(elems))]) # shape N x J x M
在给定多个索引数组的情况下,如何从 NumPy 数组中获取元素并进行广播?或者:我怎样才能简化/矢量化这个循环:
elems = np.random.rand(3, 10, 7) # shape N x I x M
ind = np.array([[1, 2], [3, 4], [0, 9]]) # shape N x J
res = np.stack([elems[i, ind[i]] for i in range(len(elems))]) # shape N x J x M
将循环索引转换为范围并使用广播:
>>> elems = np.arange(2*3*4).reshape(2,3,4)
>>> ind = np.arange(0,8,2).reshape(2, 2) % 3
>>>
>>> elems
array([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]],
[[12, 13, 14, 15],
[16, 17, 18, 19],
[20, 21, 22, 23]]])
>>> elems[np.arange(2)[:, None], ind]
array([[[ 0, 1, 2, 3],
[ 8, 9, 10, 11]],
[[16, 17, 18, 19],
[12, 13, 14, 15]]])