By combining as_strided with numpy's broadcasting, you can do some amazing things. Here are two versions of your function:
import numpy as np
from numpy.lib.stride_tricks import as_strided


def sumsqdiff(input_image, template, valid_mask=None):
    if valid_mask is None:
        valid_mask = np.ones_like(template)
    total_weight = valid_mask.sum()  # computed but not used below
    window_size = template.shape
    ssd = np.empty((input_image.shape[0] - window_size[0] + 1,
                    input_image.shape[1] - window_size[1] + 1))
    # Slide the template over every valid position of the image.
    for i in range(ssd.shape[0]):
        for j in range(ssd.shape[1]):
            sample = input_image[i:i + window_size[0], j:j + window_size[1]]
            dist = (template - sample) ** 2
            ssd[i, j] = (dist * valid_mask).sum()
    return ssd


def sumsqdiff2(input_image, template, valid_mask=None):
    if valid_mask is None:
        valid_mask = np.ones_like(template)
    total_weight = valid_mask.sum()  # computed but not used below
    window_size = template.shape
    # Create a 4-D array y, such that y[i,j,:,:] is the 2-D window
    # input_image[i:i+window_size[0], j:j+window_size[1]]
    y = as_strided(input_image,
                   shape=(input_image.shape[0] - window_size[0] + 1,
                          input_image.shape[1] - window_size[1] + 1,) +
                         window_size,
                   strides=input_image.strides * 2)
    # Compute the sum of squared differences using broadcasting.
    ssd = ((y - template) ** 2 * valid_mask).sum(axis=-1).sum(axis=-1)
    return ssd
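If the 4-D view is hard to picture, here is a small sketch that builds the same kind of view on a toy array and checks that each y[i, j] really is the window it claims to be. The array `a` and the 2x2 window are made up for illustration. The comparison against `sliding_window_view` is my own addition and assumes NumPy 1.20 or later; the functions above do not depend on it.

```python
import numpy as np
from numpy.lib.stride_tricks import as_strided, sliding_window_view

# Hypothetical 4x5 input and 2x2 window, chosen only for illustration.
a = np.arange(20.0).reshape(4, 5)
win = (2, 2)

# Same construction as in sumsqdiff2: the first two axes index the window
# position, the last two index within the window. Both pairs of axes reuse
# the strides of the input, hence `a.strides * 2` (tuple repetition).
y = as_strided(a,
               shape=(a.shape[0] - win[0] + 1,
                      a.shape[1] - win[1] + 1) + win,
               strides=a.strides * 2)

print(y.shape)   # (3, 4, 2, 2)
print(y[1, 2])   # same values as a[1:3, 2:4]

# Check every window against the corresponding plain slice.
windows_match = all(
    np.array_equal(y[i, j], a[i:i + win[0], j:j + win[1]])
    for i in range(y.shape[0])
    for j in range(y.shape[1])
)

# sliding_window_view (NumPy >= 1.20) should produce an equal, read-only view.
same_as_swv = np.array_equal(y, sliding_window_view(a, win))
print(windows_match, same_as_swv)
```

Note that as_strided does no bounds checking, so a wrong shape or strides argument can read memory outside the array; that is the main reason to prefer `sliding_window_view` when it is available.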
Here's an ipython session comparing them.
The template that I'll use for the demonstration:
In [72]: template
Out[72]:
array([[-1,  1, -1],
       [ 1,  2,  1],
       [-1,  1, -1]])
A small input, so we can check the result:
In [73]: x
Out[73]:
array([[  0.,   1.,   2.,   3.,   4.,   5.,   6.],
       [  7.,   8.,   9.,  10.,  11.,  12.,  13.],
       [ 14.,  15.,  16.,  17.,  18.,  19.,  20.],
       [ 21.,  22.,  23.,  24.,  25.,  26.,  27.],
       [ 28.,  29.,  30.,  31.,  32.,  33.,  34.]])
Apply the two functions to x and check that we get the same result:
In [74]: sumsqdiff(x, template)
Out[74]:
array([[  856.,  1005.,  1172.,  1357.,  1560.],
       [ 2277.,  2552.,  2845.,  3156.,  3485.],
       [ 4580.,  4981.,  5400.,  5837.,  6292.]])
In [75]: sumsqdiff2(x, template)
Out[75]:
array([[  856.,  1005.,  1172.,  1357.,  1560.],
       [ 2277.,  2552.,  2845.,  3156.,  3485.],
       [ 4580.,  4981.,  5400.,  5837.,  6292.]])
Now make a much bigger input "image":
In [76]: z = np.random.randn(500, 500)
and check the performance:
In [77]: %timeit sumsqdiff(z, template)
1 loops, best of 3: 3.55 s per loop
In [78]: %timeit sumsqdiff2(z, template)
10 loops, best of 3: 33 ms per loop
Not too shabby. :)
Two drawbacks:
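- The calculation in sumsqdiff2 will generate a temporary array that, for a 3x3 template, will be about 9 times the size of input_image. (In general it will be template.size times the size of input_image.)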
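- These "stride tricks" won't help you if you later Cythonize your code. When converting to Cython, you often end up putting back the loops that you got rid of when vectorizing with numpy.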
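If the temporary array in the first drawback is a problem, one possible workaround is to loop over the template elements instead of the image positions. For a 3x3 template that is only 9 iterations, and each temporary is about the size of the output rather than template.size times the input. This is a different technique from sumsqdiff2, offered only as a sketch: the name `sumsqdiff3` is hypothetical, and I have not timed it against the versions above.

```python
import numpy as np


def sumsqdiff3(input_image, template, valid_mask=None):
    """SSD computed by looping over template elements (hypothetical helper)."""
    if valid_mask is None:
        valid_mask = np.ones_like(template)
    th, tw = template.shape
    out_h = input_image.shape[0] - th + 1
    out_w = input_image.shape[1] - tw + 1
    ssd = np.zeros((out_h, out_w))
    for di in range(th):
        for dj in range(tw):
            if valid_mask[di, dj] == 0:
                continue  # masked-out template element contributes nothing
            # Element (di, dj) of every window at once, as a plain slice.
            shifted = input_image[di:di + out_h, dj:dj + out_w]
            ssd += valid_mask[di, dj] * (shifted - template[di, dj]) ** 2
    return ssd


# Same template and small input as in the session above.
template = np.array([[-1, 1, -1],
                     [1, 2, 1],
                     [-1, 1, -1]])
x = np.arange(35.0).reshape(5, 7)
result = sumsqdiff3(x, template)
print(result)
```

On the small input this should reproduce the 3x5 array shown in the session above. The explicit loops here are also short and regular, so they may be an easier starting point if you do end up porting the code to Cython.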