例如,我有两个 ndarray,形状为train_dataset
is(10000, 28, 28)
和形状为val_dateset
is (2000, 28, 28)
。
除了使用迭代之外,有没有任何有效的方法可以使用 numpy 数组函数来查找两个 ndarray 之间的重叠?
例如,我有两个 ndarray,形状为train_dataset
is(10000, 28, 28)
和形状为val_dateset
is (2000, 28, 28)
。
除了使用迭代之外,有没有任何有效的方法可以使用 numpy 数组函数来查找两个 ndarray 之间的重叠?
我从Jaime 的出色答案中学到的一个技巧是使用np.void
dtype 将输入数组中的每一行视为单个元素。这使您可以将它们视为一维数组,然后可以将其传递给np.in1d
其他设置例程之一。
import numpy as np
def find_overlap(A, B):
if not A.dtype == B.dtype:
raise TypeError("A and B must have the same dtype")
if not A.shape[1:] == B.shape[1:]:
raise ValueError("the shapes of A and B must be identical apart from "
"the row dimension")
# reshape A and B to 2D arrays. force a copy if neccessary in order to
# ensure that they are C-contiguous.
A = np.ascontiguousarray(A.reshape(A.shape[0], -1))
B = np.ascontiguousarray(B.reshape(B.shape[0], -1))
# void type that views each row in A and B as a single item
t = np.dtype((np.void, A.dtype.itemsize * A.shape[1]))
# use in1d to find rows in A that are also in B
return np.in1d(A.view(t), B.view(t))
例如:
gen = np.random.RandomState(0)
A = gen.randn(1000, 28, 28)
dupe_idx = gen.choice(A.shape[0], size=200, replace=False)
B = A[dupe_idx]
A_in_B = find_overlap(A, B)
print(np.all(np.where(A_in_B)[0] == np.sort(dupe_idx)))
# True
这种方法比 Divakar 的内存效率高得多,因为它不需要广播到(m, n, ...)
布尔数组。事实上,如果A
和B
是行优先的,则根本不需要复制。
为了比较,我稍微调整了 Divakar 和 BM 的解决方案。
def divakar(A, B):
A.shape = A.shape[0], -1
B.shape = B.shape[0], -1
return (B[:,None] == A).all(axis=(2)).any(0)
def bm(A, B):
t = 'S' + str(A.size // A.shape[0] * A.dtype.itemsize)
ma = np.frombuffer(np.ascontiguousarray(A), t)
mb = np.frombuffer(np.ascontiguousarray(B), t)
return (mb[:, None] == ma).any(0)
In [1]: na = 1000; nb = 200; rowshape = 28, 28
In [2]: %%timeit A = gen.randn(na, *rowshape); idx = gen.choice(na, size=nb, replace=False); B = A[idx]
divakar(A, B)
....:
1 loops, best of 3: 244 ms per loop
In [3]: %%timeit A = gen.randn(na, *rowshape); idx = gen.choice(na, size=nb, replace=False); B = A[idx]
bm(A, B)
....:
100 loops, best of 3: 2.81 ms per loop
In [4]: %%timeit A = gen.randn(na, *rowshape); idx = gen.choice(na, size=nb, replace=False); B = A[idx]
find_overlap(A, B)
....:
100 loops, best of 3: 15 ms per loop
如您所见,BM 的解决方案比我的小n解决方案稍快,但np.in1d
比测试所有元素的相等性(O(n log n)而不是O(n²)复杂度)具有更好的扩展性。
In [5]: na = 10000; nb = 2000; rowshape = 28, 28
In [6]: %%timeit A = gen.randn(na, *rowshape); idx = gen.choice(na, size=nb, replace=False); B = A[idx]
bm(A, B)
....:
1 loops, best of 3: 271 ms per loop
In [7]: %%timeit A = gen.randn(na, *rowshape); idx = gen.choice(na, size=nb, replace=False); B = A[idx]
find_overlap(A, B)
....:
10 loops, best of 3: 123 ms per loop
对于这种大小的阵列,Divakar 的解决方案在我的笔记本电脑上是难以处理的,因为它需要生成一个 15GB 的中间阵列,而我只有 8GB 的 RAM。
全广播在这里生成一个 10000*2000*28*28 =150 Mo 布尔数组。
为了提高效率,您可以:
打包数据,对于 200 ko 数组:
from pylab import *
N=10000
a=rand(N,28,28)
b=a[[randint(0,N,N//5)]]
packedtype='S'+ str(a.size//a.shape[0]*a.dtype.itemsize) # 'S6272'
ma=frombuffer(a,packedtype) # ma.shape=10000
mb=frombuffer(b,packedtype) # mb.shape=2000
%timeit a[:,None]==b : 102 s
%timeit ma[:,None]==mb : 800 ms
allclose((a[:,None]==b).all((2,3)),(ma[:,None]==mb)) : True
惰性字符串比较有助于减少内存,打破第一个差异:
In [31]: %timeit a[:100]==b[:100]
10000 loops, best of 3: 175 µs per loop
In [32]: %timeit a[:100]==a[:100]
10000 loops, best of 3: 133 µs per loop
In [34]: %timeit ma[:100]==mb[:100]
100000 loops, best of 3: 7.55 µs per loop
In [35]: %timeit ma[:100]==ma[:100]
10000 loops, best of 3: 156 µs per loop
这里给出了解决方案(ma[:,None]==mb).nonzero().
使用in1d
,对于(Na+Nb) ln(Na+Nb)
复杂性,反对
Na*Nb
完全比较:
%timeit in1d(ma,mb).nonzero() : 590ms
这里不是很大的收获,但渐近更好。
内存允许你可以使用broadcasting
,像这样-
val_dateset[(train_dataset[:,None] == val_dateset).all(axis=(2,3)).any(0)]
样品运行 -
In [55]: train_dataset
Out[55]:
array([[[1, 1],
[1, 1]],
[[1, 0],
[0, 0]],
[[0, 0],
[0, 1]],
[[0, 1],
[0, 0]],
[[1, 1],
[1, 0]]])
In [56]: val_dateset
Out[56]:
array([[[0, 1],
[1, 0]],
[[1, 1],
[1, 1]],
[[0, 0],
[0, 1]]])
In [57]: val_dateset[(train_dataset[:,None] == val_dateset).all(axis=(2,3)).any(0)]
Out[57]:
array([[[1, 1],
[1, 1]],
[[0, 0],
[0, 1]]])
如果元素是整数,您可以将输入数组中的每个块折叠axis=(1,2)
成一个标量,假设它们是可线性索引的数字,然后有效地使用np.in1d
或np.intersect1d
查找匹配项。
这个问题来自谷歌的在线深度学习课程?以下是我的解决方案:
sum = 0 # number of overlapping rows
for i in range(val_dataset.shape[0]): # iterate over all rows of val_dataset
overlap = (train_dataset == val_dataset[i,:,:]).all(axis=1).all(axis=1).sum()
if overlap:
sum += 1
print(sum)
使用自动广播代替迭代。您可以测试性能差异。
def overlap(a,b):
"""
returns a boolean index array for input array b representing
elements in b that are also found in a
"""
a.repeat(b.shape[0],axis=0)
b.repeat(a.shape[0],axis=0)
c = aa == bb
c = c[::a.shape[0]]
return c.all(axis=1)[:,0]
您可以使用返回的索引数组来索引b
以提取也可以在a
b[overlap(a,b)]
为简单起见,我假设您已经numpy
为此示例导入了所有内容:
from numpy import *
因此,例如,给定两个 ndarray
a = arange(4*2*2).reshape(4,2,2)
b = arange(3*2*2).reshape(3,2,2)
我们重复a
,b
使它们具有相同的形状
aa = a.repeat(b.shape[0],axis=0)
bb = b.repeat(a.shape[0],axis=0)
然后我们可以简单地比较 和 的aa
元素bb
c = aa == bb
最后,b
通过a
查看每 4 个,或者实际上,shape(a)[0]
每个c
cc == c[::a.shape[0]]
最后,我们提取一个索引数组,其中仅包含子数组中所有元素所在的元素True
c.all(axis=1)[:,0]
在我们的例子中,我们得到
array([True, True, True], dtype=bool)
要检查,请更改第一个元素b
b[0] = array([[50,60],[70,80]])
我们得到
array([False, True, True], dtype=bool)