1

假设我有来自 Pytorch 或 Keras 预测的概率,结果是 softmax 函数

from scipy.special import softmax
probs = softmax(np.random.randn(20,10),1) # 20 instances and 10 class probabilities
probs

我想从这个 numpy 数组中找到前 5 个索引。我想做的就是在结果上运行一个循环,例如:

for index in top_5_indices:
    if index in result:
        print('Found')

如果我的结果在前 5 个结果中,我会得到。

Pytorchtop-k功能,我已经看到numpy.argpartition了,但我不知道如何完成这项工作?

4

2 回答 2

4

稍微贵一点,但argsort会做:

idx = np.argsort(probs, axis=1)[:,-5:]

如果我们在谈论 pytorch:

probs = torch.from_numpy(softmax(np.random.randn(20,10),1))

values, idx = torch.topk(probs, k=5, axis=-1)
于 2020-11-27T13:33:52.050 回答
2

numpy 中的 argpartition(a, k) 函数将输入数组 a 的索引重新排列在第 k 个最小元素周围,以便较小元素的所有索引都到左边,而较大元素的所有索引都到右边。不需要对所有元素进行排序可以节省时间:argpartition 需要 O(n) 时间,而 argsort 需要 O(n log n) 时间。

因此,您可以获得 5 个最大元素的索引,如下所示:

np.argpartition(probs,-5)[-5:]
于 2020-11-27T13:39:40.373 回答