
Is there a Pythonic and efficient way to check whether a Numpy array contains at least one instance of a given row? By "efficient" I mean it terminates upon finding the first matching row rather than iterating over the entire array even if a result has already been found.

With Python lists this can be accomplished very cleanly with if row in array:, but this does not work as I would expect for NumPy arrays, as illustrated below.

With Python lists:

>>> a = [[1,2],[10,20],[100,200]]
>>> [1,2] in a
True
>>> [1,20] in a
False

but Numpy arrays give different and rather odd-looking results. (The __contains__ method of ndarray seems to be undocumented.)

>>> a = np.array([[1,2],[10,20],[100,200]])
>>> np.array([1,2]) in a
True
>>> np.array([1,20]) in a
True
>>> np.array([1,42]) in a
True
>>> np.array([42,1]) in a
False

5 Answers


You can use .tolist()

>>> a = np.array([[1,2],[10,20],[100,200]])
>>> [1,2] in a.tolist()
True
>>> [1,20] in a.tolist()
False
>>> [1,42] in a.tolist()
False
>>> [42,1] in a.tolist()
False

Or compare against every row at once (a[:] is just a view of a, so a == [1,2] behaves the same):

>>> any((a[:]==[1,2]).all(1))
True
>>> any((a[:]==[1,20]).all(1))
False

Or use a generator over the NumPy array (potentially VERY SLOW):

any(([1,2] == x).all() for x in a)     # stops on first occurrence 

Or use numpy logic functions:

any(np.equal(a,[1,2]).all(1))

If you time these:

import numpy as np
import time

n=300000
a=np.arange(n*3).reshape(n,3)
b=a.tolist()

t1,t2,t3=a[n//100][0],a[n//2][0],a[-10][0]

tests=[ ('early hit',[t1, t1+1, t1+2]),
        ('middle hit',[t2,t2+1,t2+2]),
        ('late hit', [t3,t3+1,t3+2]),
        ('miss',[0,2,0])]

fmt='\t{:20}{:.5f} seconds and is {}'     

for test, tgt in tests:
    print('\n{}: {} in {:,} elements:'.format(test,tgt,n))

    name='view'
    t1=time.time()
    result=(a[...]==tgt).all(1).any()
    t2=time.time()
    print(fmt.format(name,t2-t1,result))

    name='python list'
    t1=time.time()
    result = tgt in b
    t2=time.time()
    print(fmt.format(name,t2-t1,result))

    name='gen over numpy'
    t1=time.time()
    result=any((tgt == x).all() for x in a)
    t2=time.time()
    print(fmt.format(name,t2-t1,result))

    name='logic equal'
    t1=time.time()
    # assign the result so the printed value belongs to this test, not the previous one
    result=np.equal(a,tgt).all(1).any()
    t2=time.time()
    print(fmt.format(name,t2-t1,result))

You can see that hit or miss, the numpy routines are the same speed to search the array. The Python in operator is potentially a lot faster for an early hit, and the generator is just bad news if you have to go all the way through the array.

Here are the results for a 300,000 x 3 element array:

early hit: [9000, 9001, 9002] in 300,000 elements:
    view                0.01002 seconds and is True
    python list         0.00305 seconds and is True
    gen over numpy      0.06470 seconds and is True
    logic equal         0.00909 seconds and is True

middle hit: [450000, 450001, 450002] in 300,000 elements:
    view                0.00915 seconds and is True
    python list         0.15458 seconds and is True
    gen over numpy      3.24386 seconds and is True
    logic equal         0.00937 seconds and is True

late hit: [899970, 899971, 899972] in 300,000 elements:
    view                0.00936 seconds and is True
    python list         0.30604 seconds and is True
    gen over numpy      6.47660 seconds and is True
    logic equal         0.00965 seconds and is True

miss: [0, 2, 0] in 300,000 elements:
    view                0.00936 seconds and is False
    python list         0.01287 seconds and is False
    gen over numpy      6.49190 seconds and is False
    logic equal         0.00965 seconds and is False

And for a 3,000,000 x 3 array:

early hit: [90000, 90001, 90002] in 3,000,000 elements:
    view                0.10128 seconds and is True
    python list         0.02982 seconds and is True
    gen over numpy      0.66057 seconds and is True
    logic equal         0.09128 seconds and is True

middle hit: [4500000, 4500001, 4500002] in 3,000,000 elements:
    view                0.09331 seconds and is True
    python list         1.48180 seconds and is True
    gen over numpy      32.69874 seconds and is True
    logic equal         0.09438 seconds and is True

late hit: [8999970, 8999971, 8999972] in 3,000,000 elements:
    view                0.09868 seconds and is True
    python list         3.01236 seconds and is True
    gen over numpy      65.15087 seconds and is True
    logic equal         0.09591 seconds and is True

miss: [0, 2, 0] in 3,000,000 elements:
    view                0.09588 seconds and is False
    python list         0.12904 seconds and is False
    gen over numpy      64.46789 seconds and is False
    logic equal         0.09671 seconds and is False

This seems to indicate that the vectorized comparisons (np.equal and ==, which run at essentially the same speed here) are the fastest pure NumPy way to do this.

Answered 2013-02-08T06:18:12.377

NumPy's __contains__ is, at the time of writing, (a == b).any(), which is arguably only correct if b is a scalar. It is a bit hairy, but I believe the right general method would be (a == b).all(np.arange(a.ndim - b.ndim, a.ndim)).any(), which makes sense for all combinations of a and b dimensionality (this works only in 1.7 or later).

EDIT: Just to be clear, this is not necessarily the expected result when broadcasting is involved. Also someone might argue that it should handle the items in a separately as np.in1d does. I am not sure there is one clear way it should work.
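To make the difference concrete, here is a minimal sketch that applies both rules to the array from the question. The helper names elementwise_contains and row_contains are hypothetical, chosen only for this illustration; the first mirrors the (a == b).any() behaviour described above, and the second reduces over the last axis first, which is one way to express a row-wise test.

```python
import numpy as np

a = np.array([[1, 2], [10, 20], [100, 200]])

def elementwise_contains(a, b):
    # Mirrors the behaviour described above: True if ANY single element
    # matches after broadcasting b against a.
    return bool((a == np.asarray(b)).any())

def row_contains(a, b):
    # Row-wise test: every element of a row must match, for at least one row.
    return bool((a == np.asarray(b)).all(axis=-1).any())

for row in ([1, 2], [1, 20], [1, 42], [42, 1]):
    print(row, elementwise_contains(a, row), row_contains(a, row))
```

The first column of results reproduces the True, True, True, False pattern from the question, because a single matching element in the same column is enough for the element-wise rule.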

Now you want NumPy to stop when it finds the first occurrence. As far as I know, this does not exist at this time. It is difficult because NumPy is based mostly on ufuncs, which do the same thing over the whole array. NumPy does optimize these kinds of reductions, but effectively that only works when the array being reduced is already a boolean array (e.g. np.ones(10, dtype=bool).any()).

Otherwise it would need a special function for __contains__, which does not exist. That may seem odd, but remember that NumPy supports many data types and has a lot of machinery to select the correct type and the correct function to work on it. In other words, the ufunc machinery cannot do it, and implementing __contains__ or similar as a special case is not actually that trivial because of data types.

You can of course write it in python, or since you probably know your data type, writing it yourself in Cython/C is very simple.
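A pure-Python middle ground, sketched here under the assumption that the array is 2-D and rows are compared along the last axis, is to scan the array in blocks and stop at the first block containing a match. The name contains_row and the chunk_rows parameter are hypothetical, not part of NumPy; each block is still compared with vectorized code, so only the Python-level loop over blocks is slow.

```python
import numpy as np

def contains_row(haystack, needle, chunk_rows=4096):
    """Return True if any row of the 2-D array `haystack` equals `needle`.

    Scans `chunk_rows` rows at a time and returns as soon as a block
    contains a match, so an early hit avoids comparing the whole array.
    """
    needle = np.asarray(needle)
    for start in range(0, haystack.shape[0], chunk_rows):
        block = haystack[start:start + chunk_rows]  # slicing gives a view, no copy
        if (block == needle).all(axis=1).any():
            return True
    return False

a = np.array([[1, 2], [10, 20], [100, 200]])
print(contains_row(a, [10, 20]))
print(contains_row(a, [1, 20]))
```

The block size trades early-exit granularity against per-block overhead; the default used here is an arbitrary choice and would need tuning for real data.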


That said, it is often much better anyway to use a sorting-based approach for these things. That is a little tedious, as there is no such thing as searchsorted for a lexsort, but it works (you could also abuse scipy.spatial.cKDTree if you like). This assumes you want to compare along the last axis only:

# Unfortunately you need to use structured arrays:
sorted = np.ascontiguousarray(a).view([('', a.dtype)] * a.shape[-1]).ravel()

# Actually at this point, you can also use np.in1d, if you already have many b
# then that is even better.

sorted.sort()

b_comp = np.ascontiguousarray(b).view(sorted.dtype)
ind = sorted.searchsorted(b_comp)

result = sorted[ind] == b_comp

This also works when b is an array of rows. If you keep the sorted array around, it is also much better when you look up a single value (row) of b at a time while a stays the same (otherwise I would just use np.in1d after viewing it as a recarray). Important: you must call np.ascontiguousarray for safety. It will typically do nothing, but when it does have an effect, skipping it would be a big potential bug.
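The sorting approach can be wrapped up as in the following sketch, which assumes a 2-D array with rows along the last axis and query rows that can be cast to the same dtype. The function names build_sorted_rows and rows_in_sorted are hypothetical. The one addition to the idea above is a bounds check: searchsorted returns len(sorted_rows) for a query that sorts after every stored row, and indexing with that value would raise an IndexError.

```python
import numpy as np

def build_sorted_rows(a):
    # View each row as one structured element, then sort a copy of those rows.
    a = np.ascontiguousarray(a)
    row_dtype = np.dtype([('', a.dtype)] * a.shape[-1])
    return np.sort(a.view(row_dtype).ravel())

def rows_in_sorted(sorted_rows, b, base_dtype):
    # b may be a single row or a 2-D array of rows; cast to the stored dtype.
    b = np.ascontiguousarray(np.atleast_2d(b), dtype=base_dtype)
    b_rows = b.view(sorted_rows.dtype).ravel()
    ind = sorted_rows.searchsorted(b_rows)
    # Guard against insertion points past the end of the sorted array.
    in_range = ind < len(sorted_rows)
    found = np.zeros(len(b_rows), dtype=bool)
    found[in_range] = sorted_rows[ind[in_range]] == b_rows[in_range]
    return found

a = np.array([[100, 200], [1, 2], [10, 20]])
sorted_rows = build_sorted_rows(a)
queries = [[1, 2], [1, 20], [500, 1], [100, 200]]
print(rows_in_sorted(sorted_rows, queries, a.dtype))
```

Sorting costs more than a single linear scan, so this only pays off when the sorted rows are reused across many lookups.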

Answered 2013-02-08T12:12:57.300

I think

np.equal([1,2], a).all(axis=1)   # also,  ([1,2]==a).all(axis=1)
# array([ True, False, False], dtype=bool)

will list the rows that match. As Jamie points out, to know whether at least one such row exists, use any:

np.equal([1,2], a).all(axis=1).any()
# True

Aside:
I suspect in (and __contains__) is just as above but using any instead of all.

Answered 2013-02-08T05:30:05.910

I've compared the suggested solutions with perfplot and found that, if you're looking for a 2-tuple in a long unsorted list,

np.any(np.all(a == b, axis=1))

is the fastest solution. An explicit short-circuit loop can always be faster if a match is found in the first few rows.

[Image: perfplot runtime comparison of the approaches, plotted against len(array)]

Code to reproduce the plot:

import numpy as np
import perfplot


target = [6, 23]


def setup(n):
    return np.random.randint(0, 100, (n, 2))


def any_all(data):
    return np.any(np.all(target == data, axis=1))


def tolist(data):
    return target in data.tolist()

def loop(data):
    for row in data:
        if np.all(row == target):
            return True
    return False


def searchsorted(a):
    s = np.ascontiguousarray(a).view([('', a.dtype)] * a.shape[-1]).ravel()
    s.sort()
    t = np.ascontiguousarray(target).view(s.dtype)
    ind = s.searchsorted(t)
    return (s[ind] == t)[0]


perfplot.save(
    "out02.png",
    setup=setup,
    kernels=[any_all, tolist, loop, searchsorted],
    n_range=[2 ** k for k in range(2, 20)],
    xlabel="len(array)",
)
Answered 2021-02-26T19:41:30.347

If you really want to stop at the first occurrence, you could write a loop, like:

import numpy as np

needle = np.array([10, 20])
haystack = np.array([[1,2],[10,20],[100,200]])
found = False
for row in haystack:
    if np.all(row == needle):
        found = True
        break
print("Found: ", found)

However, I strongly suspect that it will be much slower than the other suggestions, which use NumPy routines to process the whole array.

Answered 2013-02-08T10:01:43.947