3

给定一个 numpy 数组列表,每个数组都具有相同的维度,我如何才能找到哪个数组包含逐个元素的最大值?

例如

import numpy as np
def find_index_where_max_occurs(my_list):
    # d = ...  something goes here ...
    return d

a=np.array([1,1,3,1])
b=np.array([3,1,1,1])
c=np.array([1,3,1,1])

my_list=[a,b,c]

array_of_indices_where_max_occurs = find_index_where_max_occurs(my_list)

# This is what I want:
# >>> print array_of_indices_where_max_occurs
# array([1,2,0,0])
# i.e. for the first element, the maximum value occurs in array b which is at index 1 in my_list.

任何帮助将不胜感激......谢谢!

4

3 回答 3

3

如果你想要一个数组,另一个选择:

>>> np.array((a, b, c)).argmax(axis=0)
array([1, 2, 0, 0])

所以:

def f(my_list):
    return np.array(my_list).argmax(axis=0)

这也适用于多维数组。

于 2013-04-25T14:50:20.020 回答
2

为了好玩,我意识到@Lev的原始答案比他的广义编辑更快,所以这是广义堆叠版本,比np.asarray版本快得多,但不是很优雅。

np.concatenate((a[None,...], b[None,...], c[None,...]), axis=0).argmax(0)

那是:

def bystack(arrs):
    return np.concatenate([arr[None,...] for arr in arrs], axis=0).argmax(0)

一些解释:

我为每个数组添加了一个新轴:arr[None,...]相当于与扩展为适当数字维度arr[np.newaxis,...]的位置相同。这样做的原因是因为它将沿新维度堆叠,这是因为位于前面。arr[np.newaxis,:,:,:]...np.concatenate0None

因此,例如:

In [286]: a
Out[286]: 
array([[0, 1],
       [2, 3]])

In [287]: b
Out[287]: 
array([[10, 11],
       [12, 13]])

In [288]: np.concatenate((a[None,...],b[None,...]),axis=0)
Out[288]: 
array([[[ 0,  1],
        [ 2,  3]],

       [[10, 11],
        [12, 13]]])

如果有助于理解,这也可以:

np.concatenate((a[...,None], b[...,None], c[...,None]), axis=a.ndim).argmax(a.ndim)

现在在末尾添加了新轴,因此我们必须沿最后一个轴堆叠并最大化,这将是a.ndim. 对于abc2d,我们可以这样做:

np.concatenate((a[:,:,None], b[:,:,None], c[:,:,None]), axis=2).argmax(2)

这相当于dstack我在上面的评论中提到的(dstack如果数组中不存在它,则添加第三个轴以堆叠)。

去测试:

N = 10
M = 2

a = np.random.random((N,)*M)
b = np.random.random((N,)*M)
c = np.random.random((N,)*M)

def bystack(arrs):
    return np.concatenate([arr[None,...] for arr in arrs], axis=0).argmax(0)

def byarray(arrs):
    return np.array(arrs).argmax(axis=0)

def byasarray(arrs):
    return np.asarray(arrs).argmax(axis=0)

def bylist(arrs):
    assert arrs[0].ndim == 1, "ndim must be 1"
    return [np.argmax(x) for x in zip(*arrs)]



In [240]: timeit bystack((a,b,c))
100000 loops, best of 3: 18.3 us per loop

In [241]: timeit byarray((a,b,c))
10000 loops, best of 3: 89.7 us per loop

In [242]: timeit byasarray((a,b,c))
10000 loops, best of 3: 90.0 us per loop

In [259]: timeit bylist((a,b,c))
1000 loops, best of 3: 267 us per loop
于 2013-04-25T15:34:20.357 回答
1
[np.argmax(x) for x in zip(*my_list)]

好吧,这只是一个列表,但是如果需要,您知道如何将其设为数组。:)

解释它的作用:zip(*my_list)相当于zip(a,b,c),它为您提供了一个循环生成器。循环中的每个步骤都会给你一个像 一样的元组,循环中的步骤在(a[i], b[i], c[i])哪里i。然后,np.argmax为您提供具有最大值的元素的该元组的索引。

于 2013-04-25T14:48:17.873 回答