我正在使用 pyFFTW 对 2D 复数数组进行 2D FFT。这些数组可能会变得非常大(~128 GiB),因此执行时间至关重要。(背景是光学物理学中的波前传播。)
看看下面的玩具代码:
import numpy as np
import pyfftw
import multiprocessing
a = np.random.rand(16384, 16384) + 1j*np.random.rand(16384, 16384)
fft = pyfftw.FFTW(a, a, axes = (0, 1), direction = 'FFTW_FORWARD', flags = ('FFTW_ESTIMATE', 'FFTW_UNALIGNED', 'FFTW_DESTROY_INPUT',), threads = multiprocessing.cpu_count())
a = fft()
在我的现代 64 位机器上执行 FFT 需要几秒钟。
分两步执行 2D FFT(所有列和所有行的 1D-FFT)时,结果和执行时间都保持不变:
fft = pyfftw.FFTW(a, a, axes = (0,), direction = 'FFTW_FORWARD', flags = ('FFTW_ESTIMATE', 'FFTW_UNALIGNED', 'FFTW_DESTROY_INPUT',), threads = multiprocessing.cpu_count())
a = fft()
fft = pyfftw.FFTW(a, a, axes = (1,), direction = 'FFTW_FORWARD', flags = ('FFTW_ESTIMATE', 'FFTW_UNALIGNED', 'FFTW_DESTROY_INPUT',), threads = multiprocessing.cpu_count())
a = fft()
但是,单独计算这些步骤的时间表明column-FFT 比 row-FFT 慢大约 10 倍。
我想,原因是数组被逐行保存到物理 RAM 中。确实, a.flags 给出了
C_CONTIGUOUS : True
F_CONTIGUOUS : False
OWNDATA : True
WRITEABLE : True
ALIGNED : True
UPDATEIFCOPY : False
而 a.strides 给出
(262144, 16)
因此,该数组是 C 连续的,并且似乎正确对齐。但是,删除标志 'FFTW_UNALIGNED' 会使列 FFT 大约再慢 10 倍(而行 FFT 变得稍微快一些)。
因此,我的问题是:
对齐是否有问题,或者对于 C 连续数组的物理限制,对列的访问比对行的访问慢 10 倍?
编辑:确实,10 倍似乎太大了。让我们比较一下行和列的简单读/写访问:
a[:,0:16384:2]*=1j
和
a[0:16384:2,:]*=1j
将列与偶数索引相乘(第一个变体)比将行与偶数索引(第二个变体)相乘慢大约 2 倍。
编辑:在 ipython 中输入的确切代码是
In [1]: import pyfftw
In [2]: import multiprocessing
In [3]: a = np.random.rand(16384, 16384) + 1j*np.random.rand(16384, 16384)
In [4]: fft = pyfftw.FFTW(a, a, axes = (0, 1), direction = 'FFTW_FORWARD', flags = ('FFTW_ESTIMATE', 'FFTW_UNALIGNED', 'FFTW_DESTROY_INPUT',), threads = multiprocessing.cpu_count())
In [5]: %timeit a = fft()