1

假设我有一个大小为batchx max_lenx的数组output_size,其中batch,max_lenoutput_size都对应于正自然数。我有一个索引列表,这些索引对应于维度 1(即max_len)中的各个项目。给定这些索引,我如何从数组中进行选择?

作为一个具体的例子,假设我有以下内容:

>>> l = np.random.randn(4,5,6)
>>> l.shape
(4, 5, 6)
>>> idx = [0,0,2,3]

当我选择l给定时,idx我得到:

>>> l[:,idx,:].shape
(4, 4, 6)
>>>

我也尝试过np.take但达到了相同的结果:

>>> np.take(l,idx,axis=1).shape
(4, 4, 6)
>>> 

但是,我正在关注的输出是(4,1,6)因为我试图让一个项目查看batch(即第一维)中的每个元素。我怎样才能产生具有正确形状的输出?

4

1 回答 1

2

np.take_along_axis在扩展后使用以具有与-idx相同的 ndiml

np.take_along_axis(l,np.asarray(idx)[:,None,None],axis=1)

使用显式整数数组索引 -

l[np.arange(len(idx)),idx][:,None] # skip [:,None] for (4,6) shaped o/p
于 2019-12-01T09:04:04.840 回答