0

我正在编写一个脚本，通过估计粒子集合的位移来跟踪样本的变化。第一个实现是用 Python 写的，工作正常，但在样本数量很大时耗时很长。为了解决这个问题，我尝试用 Cython 重写该方法，但由于这是我第一次使用 Cython，始终无法获得任何性能提升。我知道存在 3D FFT，而且它通常比在循环中逐个执行 2D FFT 更快，但在这种情况下，3D FFT 要么占用的内存太多，要么比 for 循环还慢。

Python函数:

import numpy as np
from scipy.fft import fftshift
import pyfftw
def python_corr(frame_a, frame_b):
    """Cross-correlate matching 2D windows of two 3D stacks using real FFTs.

    For each index ``i`` along the first axis the result plane is
    ``fftshift(irfft2(conj(rfft2(frame_a[i])) * rfft2(frame_b[i])))``,
    with the shift applied over both window axes.

    Parameters
    ----------
    frame_a, frame_b
        Stacks indexed as ``[sample, y, x]``. Only ``frame_a.shape`` is
        read; ``frame_b`` is indexed with the same sample indices.

    Returns
    -------
    numpy.ndarray
        ``float32`` array of shape ``(k, m, n)`` taken from ``frame_a``.
    """
    real_dtype = 'float32'
    complex_dtype = 'complex64'

    n_samples = frame_a.shape[0]
    height, width = frame_a.shape[1], frame_a.shape[2]

    spatial_shape = [height, width]            # shape of one real window
    spectrum_shape = [height, width // 2 + 1]  # shape of its rfft2 output

    result = np.zeros([n_samples, height, width], real_dtype)

    # The FFT plans are built once and reused for every window.
    forward = pyfftw.builders.rfft2(
        pyfftw.empty_aligned(spatial_shape, dtype=real_dtype),
        axes=[-2, -1],
    )
    inverse = pyfftw.builders.irfft2(
        pyfftw.empty_aligned(spectrum_shape, dtype=complex_dtype),
        axes=[-2, -1],
    )

    for idx in range(n_samples):
        first = frame_a[idx, :, :]
        second = frame_b[idx, :, :]
        # np.conj() yields a fresh array, so the first spectrum is captured
        # before `forward` runs again.
        # NOTE(review): pyfftw plans appear to reuse their output buffer
        # between calls -- do not hoist the raw forward() results into two
        # separate variables without copying.
        conj_first = np.conj(forward(first))
        cross_spectrum = conj_first * forward(second)
        correlation = np.real(inverse(cross_spectrum))
        result[idx, :, :] = fftshift(correlation, axes=[-2, -1])
    return result

Cython 函数:

import numpy as np
from scipy.fft import fftshift
import pyfftw

cimport numpy as np
np.import_array()
cimport cython

# Real-valued sample type: the NumPy type object is used at runtime (e.g. as a
# ``dtype=`` argument), the ctypedef is the matching C-level element type used
# in buffer / memoryview declarations.
DTYPEf = np.float32
ctypedef np.float32_t DTYPEf_t
# Complex type paired with the real type above, again as a runtime object plus
# its C-level typedef.
DTYPEc = np.complex64
ctypedef np.complex64_t DTYPEc_t

@cython.boundscheck(False)
@cython.nonecheck(False)
def cython_corr(
    np.ndarray[DTYPEf_t, ndim = 3] frame_a, 
    np.ndarray[DTYPEf_t, ndim = 3] frame_b,
):
    """Cross-correlate matching 2D windows of two float32 stacks via real FFTs.

    For each index ``i`` along the first axis the output plane is
    ``fftshift(irfft2(conj(rfft2(frame_a[i])) * rfft2(frame_b[i])))``.

    Every operation inside the loop (the pyfftw plans, ``np.conj``, the
    multiplication and ``fftshift``) is a Python-level call, so static typing
    of the intermediates buys nothing. The intermediates are therefore left
    as plain Python objects; the previous version pre-allocated and
    zero-filled six aligned buffers that were either never used or rebound
    on the first iteration.

    Parameters
    ----------
    frame_a, frame_b : float32 ndarray, 3 dimensions
        Stacks indexed as ``[sample, y, x]``. Only the shape of ``frame_a``
        is read; ``frame_b`` is indexed with the same sample indices.

    Returns
    -------
    Typed memoryview (``float32[:, :, :]``) of shape ``(k, m, n)`` taken from
    ``frame_a``. Wrap it with ``np.asarray(...)`` to obtain an ndarray.
    """
    cdef Py_ssize_t ind, k, m, n

    k = frame_a.shape[0]
    m = frame_a.shape[1] # size y of sample
    n = frame_a.shape[2] # size x of sample

    # Every plane of `out` is written in the loop below, so no zero-fill is
    # required after allocation.
    cdef DTYPEf_t[:,:,:] out = pyfftw.empty_aligned([k,m,n], dtype = DTYPEf)

    # A typed temporary is still needed for the final copy: assigning a Python
    # object to a memoryview slice would be treated as a scalar assignment.
    cdef DTYPEf_t[:,:] corr

    cdef object fft_forward
    cdef object fft_backward

    # Plans are built once and reused for every window.
    fft_forward = pyfftw.builders.rfft2(
        pyfftw.empty_aligned([m,n], dtype = DTYPEf),
        axes = [0,1],
    )
    fft_backward = pyfftw.builders.irfft2(
        pyfftw.empty_aligned([m,n//2+1], dtype = DTYPEc),
        axes = [0,1],
    )
    for ind in range(k):
        window_a = frame_a[ind,:,:]
        window_b = frame_b[ind,:,:]
        # np.conj() returns a new array, so the first spectrum is captured
        # before fft_forward is invoked a second time.
        r = np.conj(fft_forward(window_a))*fft_forward(window_b) # cross-power spectrum
        # Inverse transform, then shift Q1 --> Q3, Q2 --> Q4.
        # The fftshift is deliberately kept inside the loop so that no second
        # (k, m, n) array has to be allocated.
        corr = fftshift(fft_backward(r).real, axes = [0,1])
        out[ind,:,:] = corr
    return out

测试方法:

import time

# Benchmark both implementations on the same input.
# Seeded random data is used instead of np.empty(): uninitialised memory makes
# the run non-reproducible and may contain NaN or denormal values.
rng = np.random.default_rng(0)
aa = bb = rng.random([14000, 24, 24], dtype=np.float32) # a small test with 14000 24x24px samples
print(f'Number of samples: {aa.shape[0]}')

# perf_counter() is monotonic and has the highest available resolution, which
# makes it the appropriate clock for measuring elapsed time.
start = time.perf_counter()
corr = python_corr(aa, bb)
print(f'Time for Python: {time.perf_counter() - start}')
del corr

start = time.perf_counter()
corr = cython_corr(aa, bb)
print(f'Time for Cython: {time.perf_counter() - start}')
del corr
4

0 回答 0