我有一个场景的范围图像。我遍历图像并计算检测窗口下的平均深度变化。检测窗口根据当前位置周围像素的平均深度改变大小。我累积平均变化以产生一个简单的响应图像。
大部分时间都花在 for 循环中,在我的机器上拍摄 512x52 图像大约需要 40 多秒。我希望能加快一些速度。是否有更有效/更快的方式来遍历图像?是否有更好的 pythonic/numpy/scipy 方式来访问每个像素?还是我应该去学习cython?
编辑:我通过使用 scipy.misc.imread() 而不是 skimage.io.imread() 将运行时间减少到大约 18 秒。不知道有什么区别,我会尝试调查。
这是代码的简化版本:
import matplotlib.pylab as plt
import numpy as np
from skimage.io import imread
from skimage.transform import integral_image, integrate
import time
def intersect(a, b):
'''Determine the intersection of two rectangles'''
rect = (0,0,0,0)
r0 = max(a[0],b[0])
c0 = max(a[1],b[1])
r1 = min(a[2],b[2])
c1 = min(a[3],b[3])
# Do we have a valid intersection?
if r1 > r0 and c1 > c0:
rect = (r0,c0,r1,c1)
return rect
# Setup data
depth_src = imread("test.jpg", as_grey=True)
depth_intg = integral_image(depth_src) # integrate to find sum depth in region
depth_pts = integral_image(depth_src > 0) # integrate to find num points which have depth
boundary = (0,0,depth_src.shape[0]-1,depth_src.shape[1]-1) # rectangle to intersect with
# Image to accumulate response
out_img = np.zeros(depth_src.shape)
# Average dimensions of bbox/detection window per unit length of depth
model = (0.602,2.044) # width, height
start_time = time.time()
for (r,c), junk in np.ndenumerate(depth_src):
# Find points around current pixel
r0, c0, r1, c1 = intersect((r-1, c-1, r+1, c+1), boundary)
# Calculate average of depth of points around current pixel
scale = integrate(depth_intg, r0, c0, r1, c1) * 255 / 9.0
# Based on average depth, create the detection window
r0 = r - (model[0] * scale/2)
c0 = c - (model[1] * scale/2)
r1 = r + (model[0] * scale/2)
c1 = c + (model[1] * scale/2)
# Used scale optimised detection window to extract features
r0, c0, r1, c1 = intersect((r0,c0,r1,c1), boundary)
depth_count = integrate(depth_pts,r0,c0,r1,c1)
if depth_count:
depth_sum = integrate(depth_intg,r0,c0,r1,c1)
avg_change = depth_sum / depth_count
# Accumulate response
out_img[r0:r1,c0:c1] += avg_change
print time.time() - start_time, " seconds"
plt.imshow(out_img)
plt.gray()
plt.show()