我有一个多维数组,我需要从最后一维的每一行中获取前 k 个元素。
>>> x = np.random.random_integers(0, 100, size=(2,1,1,5))
>>> x
array([[[[99, 39, 10, 18, 68]]],
[[[22, 3, 13, 56, 2]]]])
我试图得到:
array([[[[ 99., 68.]]],
[[[ 18., 99.]]]])
我可以使用以下方法获取索引,但我不确定如何分割这些值。
>>> k = 2
>>> parts = np.flip(-1 - np.arange(k), 0)
>>> indices = np.flip(
... np.argpartition(x, parts, axis=-1)[..., -k:],
... axis=-1)
>>> indices
array([[[[0, 4]]],
[[[3, 0]]]])