4

我有两个 numpy 布尔数组(ab)。我需要找出它们中有多少元素是相等的。目前,我这样做len(a) - (a ^ b).sum()了,但据我所知,xor 操作会创建一个全新的 numpy 数组。如何在不创建不必要的临时数组的情况下有效地实现这种期望的行为?

我试过使用 numexpr,但我不能让它正常工作。它不支持 True 为 1 而 False 为 0 的概念,所以我必须使用ne.evaluate("sum(where(a==b, 1, 0))"),这大约需要两倍的时间。

编辑:我忘了提到其中一个数组实际上是另一个大小不同的数组的视图,两个数组都应该被认为是不可变的。两个数组都是二维的,大小往往在 25x40 左右。

是的,这是我程序的瓶颈,值得优化。

4

4 回答 4

2

在我的机器上,这更快:

(a == b).sum()

如果您不想使用任何额外的存储空间,我建议您使用 numba。我不太熟悉它,但这似乎运作良好。我在让 Cython 获取布尔 NumPy 数组时遇到了一些麻烦。

from numba import autojit
def pysumeq(a, b):
    tot = 0
    for i in xrange(a.shape[0]):
        for j in xrange(a.shape[1]):
            if a[i,j] == b[i,j]:
                tot += 1
    return tot
# make numba version
nbsumeq = autojit(pysumeq)
A = (rand(10,10)<.5)
B = (rand(10,10)<.5)
# do a simple dry run to get it to compile
# for this specific use case
nbsumeq(A, B)

如果您没有 numba,我建议您使用@user2357112 的答案

编辑:刚刚有一个 Cython 版本工作,这里是.pyx文件。我会和这个一起去的。

from numpy cimport ndarray as ar
cimport numpy as np
cimport cython

@cython.boundscheck(False)
@cython.wraparound(False)
def cysumeq(ar[np.uint8_t,ndim=2,cast=True] a, ar[np.uint8_t,ndim=2,cast=True] b):
    cdef int i, j, h=a.shape[0], w=a.shape[1], tot=0
    for i in xrange(h):
        for j in xrange(w):
            if a[i,j] == b[i,j]:
                tot += 1
    return tot
于 2013-07-31T04:08:12.290 回答
1

首先,您可以跳过 A*B 步骤:

>>> a
array([ True, False,  True, False,  True], dtype=bool)
>>> b
array([False,  True,  True, False,  True], dtype=bool)
>>> np.sum(~(a^b))
3

如果您不介意破坏数组 a 或 b,我不确定您是否会变得更快:

>>> a^=b   #In place xor operator
>>> np.sum(~a)
3
于 2013-07-31T03:45:26.820 回答
1

如果问题是分配和释放,请维护一个输出数组并告诉 numpy 每次都将结果放在那里:

out = np.empty_like(a) # Allocate this outside a loop and use it every iteration
num_eq = np.equal(a, b, out).sum()

不过,这仅在输入始终具有相同尺寸的情况下才有效。如果输入的大小不同,您也许可以制作一个大数组并切出每次调用所需大小的部分,但我不确定这会减慢您的速度。

于 2013-07-31T03:51:07.550 回答
0

改进 IanH 的答案,还可以通过提供mode="c"给 ndarray.

from numpy cimport ndarray as ar
cimport numpy as np
cimport cython

@cython.boundscheck(False)
@cython.wraparound(False)
cdef int cy_sum_eq(ar[np.uint8_t,ndim=2,cast=True,mode="c"] a, ar[np.uint8_t,ndim=2,cast=True,mode="c"] b):
    cdef int i, j, h=a.shape[0], w=a.shape[1], tot=0
    cdef np.uint8_t* adata = &a[0, 0]
    cdef np.uint8_t* bdata = &b[0, 0]
    for i in xrange(h):
        for j in xrange(w):
            if adata[j] == bdata[j]:
                tot += 1
        adata += w
        bdata += w
    return tot

这在我的机器上比 IanH 的 Cython 版本快了大约 40%,而且我发现在这一点上重新排列循环内容似乎没有太大的区别,这可能是由于编译器优化。此时,可以潜在地链接到使用 SSE 优化的 C 函数等来执行此操作并传递adatabdata作为uint8_t*s

于 2013-08-07T03:10:57.140 回答