3

假设我有一个长度为 4 的数组的 numpy 数组:

In [41]: arr
Out[41]:
array([[  1,  15,   0,   0],
       [ 30,  10,   0,   0],
       [ 30,  20,   0,   0],
       ...,
       [104, 139, 146,  75],
       [  9,  11, 146,  74],
       [  9, 138, 146,  75]], dtype=uint8)

我想知道:

  1. 是真的arr[1, 2, 3, 4]
  2. 如果是真的 in 的索引是[1, 2, 3, 4]多少arr

我想尽快找到它。

假设arr包含 8550420 个元素。我检查了几种方法timeit

  1. 仅用于检查而不获取 index: any(all([1, 2, 3, 4] == elt) for elt in arr)。在我的机器上运行 10 次平均需要 15.5 秒
  2. for- 基于解决方案:

    for i,e in enumerate(arr): if list(e) == [1, 2, 3, 4]: break

    平均耗时约 5.7 秒

是否存在一些更快的解决方案,例如基于 numpy 的解决方案?

4

2 回答 2

6

这是Jaime 的想法,我就是喜欢它:

import numpy as np

def asvoid(arr):
    """View the array as dtype np.void (bytes)
    This collapses ND-arrays to 1D-arrays, so you can perform 1D operations on them.
    https://stackoverflow.com/a/16216866/190597 (Jaime)"""    
    arr = np.ascontiguousarray(arr)
    return arr.view(np.dtype((np.void, arr.dtype.itemsize * arr.shape[-1])))

def find_index(arr, x):
    arr_as1d = asvoid(arr)
    x = asvoid(x)
    return np.nonzero(arr_as1d == x)[0]


arr = np.array([[  1,  15,   0,   0],
                [ 30,  10,   0,   0],
                [ 30,  20,   0,   0],
                [1, 2, 3, 4],
                [104, 139, 146,  75],
                [  9,  11, 146,  74],
                [  9, 138, 146,  75]], dtype='uint8')

arr = np.tile(arr,(1221488,1))
x = np.array([1,2,3,4], dtype='uint8')

print(find_index(arr, x))

产量

[      3      10      17 ..., 8550398 8550405 8550412]

这个想法是将数组的每一视为一个字符串。例如,

In [15]: x
Out[15]: 
array([^A^B^C^D], 
      dtype='|V4')

字符串看起来像垃圾,但它们实际上只是每一行中的底层数据,被视为字节。然后您可以比较arr_as1d == x找出哪些相等x


还有另一种方法

def find_index2(arr, x):
    return np.where((arr == x).all(axis=1))[0]

但事实证明并没有那么快:

In [34]: %timeit find_index(arr, x)
1 loops, best of 3: 209 ms per loop

In [35]: %timeit find_index2(arr, x)
1 loops, best of 3: 370 ms per loop
于 2013-07-23T13:36:45.283 回答
0

如果您执行搜索不止一次并且您不介意使用额外的内存,您可以从您的数组创建集合(我在这里使用列表,但它几乎是相同的代码):

>>> elem = [1, 2, 3, 4]    
>>> elements = [[  1,  15,   0,   0], [ 30,  10,   0,   0], [1, 2, 3, 4]]
>>> index = set([tuple(x) for x in elements])
>>> True if tuple(elem) in index else False
True
于 2013-07-23T13:19:40.783 回答