我正在编写一个脚本,通过估计粒子集合的位移来跟踪样本的变化。第一个实现是用 Python 写的,运行正常,但样本数量很大时耗时很长。为了解决这个问题,我尝试用 Cython 重写该方法,但由于这是我第一次使用 Cython,没能获得任何性能提升。我知道有 3D FFT(一次性对整个样本堆栈做变换),而且它通常比循环调用 2D FFT 更快,但在这里它要么占用内存太多,要么比 for 循环还慢。
Python函数:
import numpy as np
from scipy.fft import fftshift
import pyfftw
def python_corr(frame_a, frame_b):
    """Cross-correlate two stacks of 2-D windows, one window pair at a time.

    For every index ``ind`` along the first axis this computes the circular
    cross-correlation ``irfft2(conj(rfft2(a)) * rfft2(b))`` of
    ``frame_a[ind]`` and ``frame_b[ind]`` with pyFFTW, then applies
    ``fftshift`` so the zero-displacement term sits at ``(m // 2, n // 2)``.

    Parameters
    ----------
    frame_a, frame_b : array_like, shape (k, m, n)
        Stacks of ``k`` windows of size ``m x n``; must have identical shapes.

    Returns
    -------
    numpy.ndarray of float32, shape (k, m, n)
        The fftshifted cross-correlation plane of every window pair.

    Raises
    ------
    ValueError
        If the inputs are not 3-D or their shapes differ.
    """
    DTYPEf = 'float32'
    DTYPEc = 'complex64'
    frame_a = np.asarray(frame_a)
    frame_b = np.asarray(frame_b)
    if frame_a.ndim != 3 or frame_a.shape != frame_b.shape:
        raise ValueError(
            f"frame_a and frame_b must be 3-D with equal shapes, "
            f"got {frame_a.shape} and {frame_b.shape}"
        )
    k, m, n = frame_a.shape
    fs = (m, n)            # real window shape
    bs = (m, n // 2 + 1)   # rfft2 spectrum shape

    corr = np.empty((k, m, n), DTYPEf)  # every slice is written below
    fft_forward = pyfftw.builders.rfft2(
        pyfftw.empty_aligned(fs, dtype=DTYPEf),
        axes=(-2, -1),
    )
    # ``s=fs`` is required: without it irfft2 infers an output width of
    # 2 * (n // 2), which is n - 1 for odd n and would not fit into ``corr``.
    fft_backward = pyfftw.builders.irfft2(
        pyfftw.empty_aligned(bs, dtype=DTYPEc),
        s=fs,
        axes=(-2, -1),
    )
    for ind in range(k):
        # The FFTW objects return their internal output buffer, which the
        # next call overwrites; np.conj makes a copy before the second call.
        spec_a_conj = np.conj(fft_forward(frame_a[ind]))
        cross_spectrum = spec_a_conj * fft_forward(frame_b[ind])
        # irfft2 output is already real; fftshift returns a fresh array.
        corr[ind] = fftshift(fft_backward(cross_spectrum), axes=(-2, -1))
    return corr
Cython 函数:
import numpy as np
from scipy.fft import fftshift
import pyfftw
cimport numpy as np
np.import_array()
cimport cython
DTYPEf = np.float32
ctypedef np.float32_t DTYPEf_t
DTYPEc = np.complex64
ctypedef np.complex64_t DTYPEc_t


@cython.boundscheck(False)
@cython.wraparound(False)
@cython.nonecheck(False)
def cython_corr(
    np.ndarray[DTYPEf_t, ndim=3] frame_a not None,
    np.ndarray[DTYPEf_t, ndim=3] frame_b not None,
):
    """Cross-correlate two stacks of float32 2-D windows.

    For every index ``ind`` along the first axis, computes the circular
    cross-correlation ``irfft2(conj(rfft2(a)) * rfft2(b))`` of
    ``frame_a[ind]`` and ``frame_b[ind]``, divides it by ``m * n`` (the
    inverse-FFT normalisation) and fftshifts it so the zero-displacement
    term lands at ``(m // 2, n // 2)``.

    Parameters
    ----------
    frame_a, frame_b : ndarray of float32, shape (k, m, n)
        Window stacks; must have identical shapes.

    Returns
    -------
    ndarray of float32, shape (k, m, n)
        The fftshifted cross-correlation plane of every window pair.

    Raises
    ------
    ValueError
        If the shapes of ``frame_a`` and ``frame_b`` differ.
    """
    cdef Py_ssize_t ind, i, j, row, col
    cdef Py_ssize_t k = frame_a.shape[0]
    cdef Py_ssize_t m = frame_a.shape[1]  # window height
    cdef Py_ssize_t n = frame_a.shape[2]  # window width
    cdef Py_ssize_t nc = n // 2 + 1       # width of the r2c spectrum
    cdef Py_ssize_t half_m = m // 2
    cdef Py_ssize_t half_n = n // 2
    cdef float scale
    cdef DTYPEf_t[:, ::1] win_a, win_b, res
    cdef DTYPEc_t[:, ::1] spec_a_v, spec_b_v
    cdef DTYPEf_t[:, :, ::1] out_v

    if (frame_b.shape[0] != k or frame_b.shape[1] != m
            or frame_b.shape[2] != n):
        raise ValueError("frame_a and frame_b must have the same shape")

    # Aligned work buffers, allocated once and bound to the FFTW plans.
    in_a = pyfftw.empty_aligned((m, n), dtype=DTYPEf)
    in_b = pyfftw.empty_aligned((m, n), dtype=DTYPEf)
    spec_a = pyfftw.empty_aligned((m, nc), dtype=DTYPEc)
    spec_b = pyfftw.empty_aligned((m, nc), dtype=DTYPEc)
    corr_buf = pyfftw.empty_aligned((m, n), dtype=DTYPEf)

    # Plan before filling the buffers: FFTW_MEASURE planning may overwrite
    # them. Plans are reused for every window via execute().
    fft_a = pyfftw.FFTW(in_a, spec_a, axes=(0, 1))
    fft_b = pyfftw.FFTW(in_b, spec_b, axes=(0, 1))
    # The (m, n) real output array fixes the inverse length to n, so odd
    # window widths are handled correctly (the builder-based version
    # silently produced n - 1 columns for odd n).
    fft_back = pyfftw.FFTW(spec_b, corr_buf, axes=(0, 1),
                           direction='FFTW_BACKWARD')

    # execute() does not normalise the inverse transform; do it by hand.
    scale = 1.0 / (m * n)

    win_a = in_a
    win_b = in_b
    spec_a_v = spec_a
    spec_b_v = spec_b
    res = corr_buf
    out = np.empty((k, m, n), dtype=DTYPEf)
    out_v = out

    for ind in range(k):
        # Copy the current window pair into the planned input buffers.
        for i in range(m):
            for j in range(n):
                win_a[i, j] = frame_a[ind, i, j]
                win_b[i, j] = frame_b[ind, i, j]
        fft_a.execute()
        fft_b.execute()
        # Cross-power spectrum conj(A) * B, written in place into the
        # backward plan's input buffer.
        for i in range(m):
            for j in range(nc):
                spec_b_v[i, j] = spec_a_v[i, j].conjugate() * spec_b_v[i, j]
        fft_back.execute()
        # Normalise and apply fftshift (roll by m//2, n//2) while storing.
        for i in range(m):
            row = i + half_m
            if row >= m:
                row -= m
            for j in range(n):
                col = j + half_n
                if col >= n:
                    col -= n
                out_v[ind, row, col] = res[i, j] * scale
    return out
测试方法:
import time

# Small benchmark: 14000 pairs of 24x24 px float32 windows. Use seeded random
# data rather than np.empty, whose uninitialised contents (possibly NaN or
# denormals) make both the timings and the results meaningless; allocate
# float32 directly instead of float64 + astype, and use two distinct stacks
# so this is a genuine cross-correlation rather than an autocorrelation.
rng = np.random.default_rng(0)
aa = rng.random((14000, 24, 24), dtype=np.float32)
bb = rng.random((14000, 24, 24), dtype=np.float32)
print(f'Number of samples: {aa.shape[0]}')

start = time.perf_counter()  # monotonic, high-resolution benchmark clock
corr_py = python_corr(aa, bb)
print(f'Time for Python: {time.perf_counter() - start}')

start = time.perf_counter()
corr_cy = np.asarray(cython_corr(aa, bb))
print(f'Time for Cython: {time.perf_counter() - start}')

# Sanity check: a faster implementation is only useful if it agrees.
np.testing.assert_allclose(corr_cy, corr_py, rtol=1e-4, atol=1e-3)
del corr_py, corr_cy