2

我需要确定矩阵 a 中 k 个最大值的位置(索引)是否与二进制指标矩阵 b 处于相同位置。

import numpy as np
a = np.matrix([[.8,.2,.6,.4],[.9,.3,.8,.6],[.2,.6,.8,.4],[.3,.3,.1,.8]])
b = np.matrix([[1,0,0,1],[1,0,1,1],[1,1,1,0],[1,0,0,1]])
print "a:\n", a
print "b:\n", b

d = argsort(a)
d[:,2:] # Return whether these indices are in 'b'

回报:

a:
[[ 0.8  0.2  0.6  0.4]
 [ 0.9  0.3  0.8  0.6]
 [ 0.2  0.6  0.8  0.4]
 [ 0.3  0.3  0.1  0.8]]
b:
[[1 0 0 1]
 [1 0 1 1]
 [1 1 1 0]
 [1 0 0 1]]

matrix([[2, 0],
        [2, 0],
        [1, 2],
        [1, 3]])

我想比较从最后一个结果返回的索引,如果b在这些位置有索引,则返回计数。对于此示例,最终所需的结果将是:

1
2
2
1

换句话说,在 的第一行中a,前 2 个值仅对应于 中的一个,以此类推b

任何想法如何有效地做到这一点?也许 argsort 在这里是错误的方法。谢谢。

4

3 回答 3

1

当您使用时,argsort您会从 minimum0到 maximum 3,因此您可以反转它[::-1]以获得最大0和最小3

s = np.argsort(a, axis=1)[:,::-1]   
#array([[0, 2, 3, 1],
#       [0, 2, 3, 1],
#       [2, 1, 3, 0],
#       [3, 1, 0, 2]])

现在您可以使用np.take获取0最大值所在的1s 和第二个最大值所在的 s:

s2 = s + (np.arange(s.shape[0])*s.shape[1])[:,None]
s = np.take(s.flatten(),s2)
#array([[0, 3, 1, 2],
#       [0, 3, 1, 2],
#       [3, 1, 0, 2],
#       [2, 1, 3, 0]])

b中,0值应替换为 a np.nan,从而0==np.nan给出False

b = np.float_(b)
b[b==0] = np.nan
#array([[  1.,  nan,  nan,   1.],
#       [  1.,  nan,   1.,   1.],
#       [  1.,   1.,   1.,  nan],
#       [  1.,  nan,  nan,   1.]])

以下比较将为您提供所需的结果:

print np.logical_or(s==b-1, s==b).sum(axis=1)
#[[1]
# [2]
# [2]
# [1]]

一般情况下,将 的n最大值a与二进制文件进行比较b

def check_a_b(a,b,n=2):
    b = np.float_(b)
    b[b==0] = np.nan
    s = np.argsort(a, axis=1)[:,::-1]
    s2 = s + (np.arange(s.shape[0])*s.shape[1])[:,None]
    s = np.take(s.flatten(),s2)
    ans = s==(b-1)
    for i in range(n-1):
        ans = np.logical_or( ans, s==b+i )
    return ans.sum(axis=1)

这将在logical_or.

于 2013-08-08T09:19:34.997 回答
1

另一种更简单、更快的方法,基于以下事实:

True*1=1, True*0=0, False*0=0, and False*1=0

是:

def check_a_b_new(a,b,n=2):
    s = np.argsort(a.view(np.ndarray), axis=1)[:,::-1]
    s2 = s + (np.arange(s.shape[0])*s.shape[1])[:,None]
    s = np.take(s.flatten(),s2)
    return ((s < n)*b.view(np.ndarray)).sum(axis=1)

避免0tonp.nan转换,以及 Pythonfor循环对于高值的n.

于 2013-08-08T15:29:43.490 回答
0

为了响应 Saullo 的巨大帮助,我能够接受他的工作并将解决方案减少到三行。谢谢索洛!

#Inputs
k = 2
a = np.matrix([[.8,.2,.6,.4],[.9,.3,.8,.6],[.2,.6,.8,.4],[.3,.3,.1,.8]])
b = np.matrix([[1,0,0,1],[1,0,1,1],[1,1,1,0],[1,0,0,1]])
print "a:\n", a
print "b:\n", b

# Return values of interest
s = argsort(a.view(np.ndarray), axis=1)[:,::-1]
s2 = s + (arange(s.shape[0])*s.shape[1])[:,None]
out = take(b,s2).view(np.ndarray)[::,:k].sum(axis=1)
print out

给出:

a:
[[ 0.8  0.2  0.6  0.4]
 [ 0.9  0.3  0.8  0.6]
 [ 0.2  0.6  0.8  0.4]
 [ 0.3  0.3  0.1  0.8]]
b:
[[1 0 0 1]
 [1 0 1 1]
 [1 1 1 0]
 [1 0 0 1]]
Out:
[1 2 2 1]
于 2013-08-09T17:39:58.470 回答