正如@Murali 在评论中指出的那样,method1
效率不是很高,因为它没有成功使用BLAS调用,而不是使用method2
哪个调用。事实上,由于np.einsum
OpenBLAS (Numpy 在大多数机器上使用),method1
它按顺序计算结果,同时method2
大部分并行运行。话虽如此,这method2
是次优的,因为它没有完全使用可用的内核(部分计算是按顺序完成的)并且似乎没有有效地使用缓存。在我的 6 核机器上,它几乎不使用所有内核的 50%。
更快的实施
加快此计算的一种解决方案是为此编写高度优化的Numba并行代码。
首先,半天真的实现是使用许多 for 循环来计算 Einstein 求和并重塑输入/输出数组,以便 Numba 可以更好地优化代码(例如展开、使用SIMD 指令)。结果如下:
@nb.njit('float64[:,:,:,:,::1](float64[:,:,:,:,:,:,::1], float64[:,:,:,:,::1])')
def compute(a, b):
sN, sH, sW, sg, si, sh, sw = a.shape
so = b.shape[1]
assert b.shape == (sg, so, si, sh, sw)
ra = a.reshape(sN*sH*sW, sg, si*sh*sw)
rb = b.reshape(sg, so, si*sh*sw)
out = np.empty((sN*sH*sW, sg, so), dtype=np.float64)
for NHW in range(sN*sH*sW):
for g in range(sg):
for o in range(so):
s = 0.0
# Reduction
for ihw in range(si*sh*sw):
s += ra[NHW, g, ihw] * rb[g, o, ihw]
out[NHW, g, o] = s
return out.reshape((sN, sH, sW, sg, so))
请注意,假设输入数组是连续的。如果不是这种情况,请考虑执行复制(与计算相比便宜)。
虽然上面的代码有效,但它远非高效。以下是一些可以执行的改进:
- 并行运行最外层
NHW
循环;
- 使用 Numba 标志
fastmath=True
。如果输入数据包含特殊值,如 NaN 或 +inf/-inf,则此标志是不安全的。但是,此标志有助于编译器使用 SIMD 指令生成更快的代码(否则这是不可能的,因为 IEEE-754 浮点运算不是关联的);
- 交换
NHW
基于 - 的循环和g
基于 - 的循环会导致更好的性能,因为它提高了缓存局部性(rb
更有可能适合主流 CPU 的最后一级缓存,否则它可能会从 RAM 中获取);
- 利用寄存器阻塞,使处理器的更好的 SIMD 计算单元饱和,减少内存层次的压力;
- 通过拆分基于 - 的循环来使用平铺,因此几乎可以从较低级别的缓存(例如 L1 或 L2)中完全读取。
o
rb
除了最后一项之外,所有这些改进都在以下代码中实现:
@nb.njit('float64[:,:,:,:,::1](float64[:,:,:,:,:,:,::1], float64[:,:,:,:,::1])', parallel=True, fastmath=True)
def method3(a, b):
sN, sH, sW, sg, si, sh, sw = a.shape
so = b.shape[1]
assert b.shape == (sg, so, si, sh, sw)
ra = a.reshape(sN*sH*sW, sg, si*sh*sw)
rb = b.reshape(sg, so, si*sh*sw)
out = np.zeros((sN*sH*sW, sg, so), dtype=np.float64)
for g in range(sg):
for k in nb.prange((sN*sH*sW)//2):
NHW = k*2
so_vect_max = (so // 4) * 4
for o in range(0, so_vect_max, 4):
s00 = s01 = s02 = s03 = s10 = s11 = s12 = s13 = 0.0
# Useful since Numba does not optimize well the following loop otherwise
ra_row0 = ra[NHW+0, g, :]
ra_row1 = ra[NHW+1, g, :]
rb_row0 = rb[g, o+0, :]
rb_row1 = rb[g, o+1, :]
rb_row2 = rb[g, o+2, :]
rb_row3 = rb[g, o+3, :]
# Highly-optimized reduction using register blocking
for ihw in range(si*sh*sw):
ra_0 = ra_row0[ihw]
ra_1 = ra_row1[ihw]
rb_0 = rb_row0[ihw]
rb_1 = rb_row1[ihw]
rb_2 = rb_row2[ihw]
rb_3 = rb_row3[ihw]
s00 += ra_0 * rb_0; s01 += ra_0 * rb_1
s02 += ra_0 * rb_2; s03 += ra_0 * rb_3
s10 += ra_1 * rb_0; s11 += ra_1 * rb_1
s12 += ra_1 * rb_2; s13 += ra_1 * rb_3
out[NHW+0, g, o+0] = s00; out[NHW+0, g, o+1] = s01
out[NHW+0, g, o+2] = s02; out[NHW+0, g, o+3] = s03
out[NHW+1, g, o+0] = s10; out[NHW+1, g, o+1] = s11
out[NHW+1, g, o+2] = s12; out[NHW+1, g, o+3] = s13
# Remaining part for `o`
for o in range(so_vect_max, so):
for ihw in range(si*sh*sw):
out[NHW, g, o] += ra[NHW, g, ihw] * rb[g, o, ihw]
out[NHW+1, g, o] += ra[NHW+1, g, ihw] * rb[g, o, ihw]
# Remaining part for `k`
if (sN*sH*sW) % 2 == 1:
k = sN*sH*sW - 1
for o in range(so):
for ihw in range(si*sh*sw):
out[k, g, o] += ra[k, g, ihw] * rb[g, o, ihw]
return out.reshape((sN, sH, sW, sg, so))
这段代码更加复杂和丑陋,但也更加高效。我没有实现平铺优化,因为它会使代码的可读性更差。但是,它应该会在多核处理器(尤其是具有小型 L2/L3 缓存的处理器)上显着加快代码速度。
性能结果
以下是我的 i5-9600KF 6 核处理器的性能结果:
method1: 816 ms
method2: 104 ms
method3: 40 ms
Theoretical optimal: 9 ms (optimistic lower bound)
该代码比method2
. 由于最佳时间比 . 好 4 倍左右,因此还有改进的余地method3
。
Numba 不生成快速代码的主要原因是底层 JIT 未能有效地矢量化循环。实施平铺策略应该会稍微提高执行时间,非常接近最佳时间。平铺策略对于更大的阵列至关重要。如果so
更大,则尤其如此。
如果您想要更快的实现,您当然需要直接使用 SIMD 内在函数(遗憾的是不可移植)或 SIMD 库(例如 XSIMD)编写 C/C++ 本机代码。
如果您想要更快的实现,那么您需要使用更快的硬件(具有更多内核)或更专用的硬件。基于服务器的 GPU(即不是个人计算机)不应该能够加速很多这样的计算,因为您的输入很小,显然受计算限制并且大量使用 FMA 浮点运算。第一个开始是尝试cupy.einsum
。
底层:底层分析
为了理解为什么method1
没有更快,我检查了执行的代码。这是主循环:
1a0:┌─→; Part of the reduction (see below)
│ movapd xmm0,XMMWORD PTR [rdi-0x1000]
│
│ ; Decrement the number of loop cycle
│ sub r9,0x8
│
│ ; Prefetch items so to reduce the impact
│ ; of the latency of reading from the RAM.
│ prefetcht0 BYTE PTR [r8]
│ prefetcht0 BYTE PTR [rdi]
│
│ ; Part of the reduction (see below)
│ mulpd xmm0,XMMWORD PTR [r8-0x1000]
│
│ ; Increment iterator for the two arrays
│ add rdi,0x40
│ add r8,0x40
│
│ ; Main computational part:
│ ; reduction using add+mul SSE2 instructions
│ addpd xmm1,xmm0 <--- Slow
│ movapd xmm0,XMMWORD PTR [rdi-0x1030]
│ mulpd xmm0,XMMWORD PTR [r8-0x1030]
│ addpd xmm1,xmm0 <--- Slow
│ movapd xmm0,XMMWORD PTR [rdi-0x1020]
│ mulpd xmm0,XMMWORD PTR [r8-0x1020]
│ addpd xmm0,xmm1 <--- Slow
│ movapd xmm1,XMMWORD PTR [rdi-0x1010]
│ mulpd xmm1,XMMWORD PTR [r8-0x1010]
│ addpd xmm1,xmm0 <--- Slow
│
│ ; Is the loop over?
│ ; If not, jump to the beginning of the loop.
├──cmp r9,0x7
└──jg 1a0
事实证明,Numpy 使用 SSE2 指令集(在所有 x86-64 处理器上都可用)。然而,我的机器,就像几乎所有相对较新的处理器一样,支持 AVX 指令集,每条指令一次可以计算两倍以上的项目。我的机器还支持在这种情况下快两倍的熔丝倍增指令 (FMA)。此外,循环清楚地受到将addpd
结果累积在几乎相同的寄存器中的限制。处理器无法有效地执行它们,因为addpd
需要几个延迟周期,并且在现代 x86-64 处理器上最多可以同时执行两个(这在此处是不可能的,因为一次只有一个指令可以执行累积xmm1
)。
这是method2
(dgemm
OpenBLAS的调用)的主要计算部分的执行代码:
6a40:┌─→vbroadcastsd ymm0,QWORD PTR [rsi-0x60]
│ vbroadcastsd ymm1,QWORD PTR [rsi-0x58]
│ vbroadcastsd ymm2,QWORD PTR [rsi-0x50]
│ vbroadcastsd ymm3,QWORD PTR [rsi-0x48]
│ vfmadd231pd ymm4,ymm0,YMMWORD PTR [rdi-0x80]
│ vfmadd231pd ymm5,ymm1,YMMWORD PTR [rdi-0x60]
│ vbroadcastsd ymm0,QWORD PTR [rsi-0x40]
│ vbroadcastsd ymm1,QWORD PTR [rsi-0x38]
│ vfmadd231pd ymm6,ymm2,YMMWORD PTR [rdi-0x40]
│ vfmadd231pd ymm7,ymm3,YMMWORD PTR [rdi-0x20]
│ vbroadcastsd ymm2,QWORD PTR [rsi-0x30]
│ vbroadcastsd ymm3,QWORD PTR [rsi-0x28]
│ vfmadd231pd ymm4,ymm0,YMMWORD PTR [rdi]
│ vfmadd231pd ymm5,ymm1,YMMWORD PTR [rdi+0x20]
│ vfmadd231pd ymm6,ymm2,YMMWORD PTR [rdi+0x40]
│ vfmadd231pd ymm7,ymm3,YMMWORD PTR [rdi+0x60]
│ add rsi,0x40
│ add rdi,0x100
├──dec rax
└──jne 6a40
这个循环更加优化:它利用了 AVX 指令集以及 FMA 指令集(即vfmadd231pd
指令)。此外,循环更好地展开,并且没有像 Numpy 代码中那样的延迟/依赖问题。然而,虽然这个循环非常高效,但由于在 Numpy 中进行了一些顺序检查和在 OpenBLAS 中执行顺序复制,内核并没有得到有效使用。此外,我不确定在这种情况下循环是否有效地使用了缓存,因为在我的机器上的 RAM 中执行了大量的读/写操作。实际上,由于许多缓存未命中,RAM 吞吐量约为 15 GiB/s(超过 35~40 GiB/s),而吞吐量method3
为 6 GiB/s(因此在缓存中完成了更多工作),执行速度明显更快。
这是主要计算部分的执行代码method3
:
.LBB0_5:
vorpd 2880(%rsp), %ymm8, %ymm0
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rsi,%ymm0,8), %ymm2
vmovupd %ymm2, 3040(%rsp)
vorpd 2848(%rsp), %ymm8, %ymm1
vpcmpeqd %ymm2, %ymm2, %ymm2
vgatherqpd %ymm2, (%rsi,%ymm1,8), %ymm3
vmovupd %ymm3, 3104(%rsp)
vorpd 2912(%rsp), %ymm8, %ymm2
vpcmpeqd %ymm3, %ymm3, %ymm3
vgatherqpd %ymm3, (%rsi,%ymm2,8), %ymm4
vmovupd %ymm4, 3136(%rsp)
vorpd 2816(%rsp), %ymm8, %ymm3
vpcmpeqd %ymm4, %ymm4, %ymm4
vgatherqpd %ymm4, (%rsi,%ymm3,8), %ymm5
vmovupd %ymm5, 3808(%rsp)
vorpd 2784(%rsp), %ymm8, %ymm9
vpcmpeqd %ymm4, %ymm4, %ymm4
vgatherqpd %ymm4, (%rsi,%ymm9,8), %ymm5
vmovupd %ymm5, 3840(%rsp)
vorpd 2752(%rsp), %ymm8, %ymm10
vpcmpeqd %ymm4, %ymm4, %ymm4
vgatherqpd %ymm4, (%rsi,%ymm10,8), %ymm5
vmovupd %ymm5, 3872(%rsp)
vpaddq 2944(%rsp), %ymm8, %ymm4
vorpd 2720(%rsp), %ymm8, %ymm11
vpcmpeqd %ymm13, %ymm13, %ymm13
vgatherqpd %ymm13, (%rsi,%ymm11,8), %ymm5
vmovupd %ymm5, 3904(%rsp)
vpcmpeqd %ymm13, %ymm13, %ymm13
vgatherqpd %ymm13, (%rdx,%ymm0,8), %ymm5
vmovupd %ymm5, 3552(%rsp)
vpcmpeqd %ymm0, %ymm0, %ymm0
vgatherqpd %ymm0, (%rdx,%ymm1,8), %ymm5
vmovupd %ymm5, 3616(%rsp)
vpcmpeqd %ymm0, %ymm0, %ymm0
vgatherqpd %ymm0, (%rdx,%ymm2,8), %ymm1
vmovupd %ymm1, 3648(%rsp)
vpcmpeqd %ymm0, %ymm0, %ymm0
vgatherqpd %ymm0, (%rdx,%ymm3,8), %ymm1
vmovupd %ymm1, 3680(%rsp)
vpcmpeqd %ymm0, %ymm0, %ymm0
vgatherqpd %ymm0, (%rdx,%ymm9,8), %ymm1
vmovupd %ymm1, 3712(%rsp)
vpcmpeqd %ymm0, %ymm0, %ymm0
vgatherqpd %ymm0, (%rdx,%ymm10,8), %ymm1
vmovupd %ymm1, 3744(%rsp)
vpcmpeqd %ymm0, %ymm0, %ymm0
vgatherqpd %ymm0, (%rdx,%ymm11,8), %ymm1
vmovupd %ymm1, 3776(%rsp)
vpcmpeqd %ymm0, %ymm0, %ymm0
vgatherqpd %ymm0, (%rsi,%ymm4,8), %ymm6
vpcmpeqd %ymm0, %ymm0, %ymm0
vgatherqpd %ymm0, (%rdx,%ymm4,8), %ymm3
vpaddq 2688(%rsp), %ymm8, %ymm0
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rsi,%ymm0,8), %ymm7
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rdx,%ymm0,8), %ymm4
vmovupd %ymm4, 3360(%rsp)
vpaddq 2656(%rsp), %ymm8, %ymm0
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rsi,%ymm0,8), %ymm13
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rdx,%ymm0,8), %ymm4
vmovupd %ymm4, 3392(%rsp)
vpaddq 2624(%rsp), %ymm8, %ymm0
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rsi,%ymm0,8), %ymm15
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rdx,%ymm0,8), %ymm4
vmovupd %ymm4, 3424(%rsp)
vpaddq 2592(%rsp), %ymm8, %ymm0
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rsi,%ymm0,8), %ymm9
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rdx,%ymm0,8), %ymm4
vmovupd %ymm4, 3456(%rsp)
vpaddq 2560(%rsp), %ymm8, %ymm0
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rsi,%ymm0,8), %ymm14
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rdx,%ymm0,8), %ymm4
vmovupd %ymm4, 3488(%rsp)
vpaddq 2528(%rsp), %ymm8, %ymm0
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rsi,%ymm0,8), %ymm11
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rdx,%ymm0,8), %ymm4
vmovupd %ymm4, 3520(%rsp)
vpaddq 2496(%rsp), %ymm8, %ymm0
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rsi,%ymm0,8), %ymm10
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rdx,%ymm0,8), %ymm4
vmovupd %ymm4, 3584(%rsp)
vpaddq 2464(%rsp), %ymm8, %ymm0
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rdx,%ymm0,8), %ymm2
vpaddq 2432(%rsp), %ymm8, %ymm0
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rdx,%ymm0,8), %ymm12
vpaddq 2400(%rsp), %ymm8, %ymm0
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rdx,%ymm0,8), %ymm4
vmovupd %ymm4, 3168(%rsp)
vpaddq 2368(%rsp), %ymm8, %ymm0
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rdx,%ymm0,8), %ymm4
vmovupd %ymm4, 3200(%rsp)
vpaddq 2336(%rsp), %ymm8, %ymm0
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rdx,%ymm0,8), %ymm4
vmovupd %ymm4, 3232(%rsp)
vpaddq 2304(%rsp), %ymm8, %ymm0
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rdx,%ymm0,8), %ymm4
vmovupd %ymm4, 3264(%rsp)
vpaddq 2272(%rsp), %ymm8, %ymm0
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rdx,%ymm0,8), %ymm4
vmovupd %ymm4, 3296(%rsp)
vpaddq 2240(%rsp), %ymm8, %ymm0
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rdx,%ymm0,8), %ymm4
vmovupd %ymm4, 3328(%rsp)
vpaddq 2208(%rsp), %ymm8, %ymm0
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rdx,%ymm0,8), %ymm4
vpaddq 2176(%rsp), %ymm8, %ymm0
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rdx,%ymm0,8), %ymm5
vmovupd %ymm5, 2976(%rsp)
vpaddq 2144(%rsp), %ymm8, %ymm0
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rdx,%ymm0,8), %ymm5
vmovupd %ymm5, 3008(%rsp)
vpaddq 2112(%rsp), %ymm8, %ymm0
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rdx,%ymm0,8), %ymm5
vmovupd %ymm5, 3072(%rsp)
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rsi,%ymm8,8), %ymm0
vpcmpeqd %ymm5, %ymm5, %ymm5
vgatherqpd %ymm5, (%rdx,%ymm8,8), %ymm1
vmovupd 768(%rsp), %ymm5
vfmadd231pd %ymm0, %ymm1, %ymm5
vmovupd %ymm5, 768(%rsp)
vmovupd 32(%rsp), %ymm5
vfmadd231pd %ymm0, %ymm3, %ymm5
vmovupd %ymm5, 32(%rsp)
vmovupd 1024(%rsp), %ymm5
vfmadd231pd %ymm0, %ymm2, %ymm5
vmovupd %ymm5, 1024(%rsp)
vmovupd 1280(%rsp), %ymm5
vfmadd231pd %ymm0, %ymm4, %ymm5
vmovupd %ymm5, 1280(%rsp)
vmovupd 1344(%rsp), %ymm0
vfmadd231pd %ymm1, %ymm6, %ymm0
vmovupd %ymm0, 1344(%rsp)
vmovupd 480(%rsp), %ymm0
vfmadd231pd %ymm3, %ymm6, %ymm0
vmovupd %ymm0, 480(%rsp)
vmovupd 1600(%rsp), %ymm0
vfmadd231pd %ymm2, %ymm6, %ymm0
vmovupd %ymm0, 1600(%rsp)
vmovupd 1856(%rsp), %ymm0
vfmadd231pd %ymm4, %ymm6, %ymm0
vmovupd %ymm0, 1856(%rsp)
vpaddq 2080(%rsp), %ymm8, %ymm0
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rdx,%ymm0,8), %ymm2
vpaddq 2048(%rsp), %ymm8, %ymm0
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rdx,%ymm0,8), %ymm4
vmovupd 800(%rsp), %ymm0
vmovupd 3552(%rsp), %ymm1
vmovupd 3040(%rsp), %ymm3
vfmadd231pd %ymm3, %ymm1, %ymm0
vmovupd %ymm0, 800(%rsp)
vmovupd 64(%rsp), %ymm0
vmovupd 3360(%rsp), %ymm5
vfmadd231pd %ymm3, %ymm5, %ymm0
vmovupd %ymm0, 64(%rsp)
vmovupd 1056(%rsp), %ymm0
vfmadd231pd %ymm3, %ymm12, %ymm0
vmovupd %ymm0, 1056(%rsp)
vmovupd 288(%rsp), %ymm0
vmovupd 2976(%rsp), %ymm6
vfmadd231pd %ymm3, %ymm6, %ymm0
vmovupd %ymm0, 288(%rsp)
vmovupd 1376(%rsp), %ymm0
vfmadd231pd %ymm1, %ymm7, %ymm0
vmovupd %ymm0, 1376(%rsp)
vmovupd 512(%rsp), %ymm0
vfmadd231pd %ymm5, %ymm7, %ymm0
vmovupd %ymm0, 512(%rsp)
vmovupd 1632(%rsp), %ymm0
vfmadd231pd %ymm12, %ymm7, %ymm0
vmovupd %ymm0, 1632(%rsp)
vmovupd 1888(%rsp), %ymm0
vfmadd231pd %ymm6, %ymm7, %ymm0
vmovupd %ymm0, 1888(%rsp)
vmovupd 832(%rsp), %ymm0
vmovupd 3616(%rsp), %ymm1
vmovupd 3104(%rsp), %ymm6
vfmadd231pd %ymm6, %ymm1, %ymm0
vmovupd %ymm0, 832(%rsp)
vmovupd 96(%rsp), %ymm0
vmovupd 3392(%rsp), %ymm3
vfmadd231pd %ymm6, %ymm3, %ymm0
vmovupd %ymm0, 96(%rsp)
vmovupd 1088(%rsp), %ymm0
vmovupd 3168(%rsp), %ymm5
vfmadd231pd %ymm6, %ymm5, %ymm0
vmovupd %ymm0, 1088(%rsp)
vmovupd 320(%rsp), %ymm0
vmovupd 3008(%rsp), %ymm7
vfmadd231pd %ymm6, %ymm7, %ymm0
vmovupd %ymm0, 320(%rsp)
vmovupd 1408(%rsp), %ymm0
vfmadd231pd %ymm1, %ymm13, %ymm0
vmovupd %ymm0, 1408(%rsp)
vmovupd 544(%rsp), %ymm0
vfmadd231pd %ymm3, %ymm13, %ymm0
vmovupd %ymm0, 544(%rsp)
vmovupd 1664(%rsp), %ymm0
vfmadd231pd %ymm5, %ymm13, %ymm0
vmovupd %ymm0, 1664(%rsp)
vmovupd 1920(%rsp), %ymm0
vfmadd231pd %ymm7, %ymm13, %ymm0
vmovupd %ymm0, 1920(%rsp)
vpaddq 2016(%rsp), %ymm8, %ymm0
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rdx,%ymm0,8), %ymm3
vmovupd 864(%rsp), %ymm0
vmovupd 3648(%rsp), %ymm1
vmovupd 3136(%rsp), %ymm6
vfmadd231pd %ymm6, %ymm1, %ymm0
vmovupd %ymm0, 864(%rsp)
vmovupd 128(%rsp), %ymm0
vmovupd 3424(%rsp), %ymm5
vfmadd231pd %ymm6, %ymm5, %ymm0
vmovupd %ymm0, 128(%rsp)
vmovupd 1120(%rsp), %ymm0
vmovupd 3200(%rsp), %ymm7
vfmadd231pd %ymm6, %ymm7, %ymm0
vmovupd %ymm0, 1120(%rsp)
vmovupd 352(%rsp), %ymm0
vmovupd 3072(%rsp), %ymm12
vfmadd231pd %ymm6, %ymm12, %ymm0
vmovupd %ymm0, 352(%rsp)
vmovupd 1440(%rsp), %ymm0
vfmadd231pd %ymm1, %ymm15, %ymm0
vmovupd %ymm0, 1440(%rsp)
vmovupd 576(%rsp), %ymm0
vfmadd231pd %ymm5, %ymm15, %ymm0
vmovupd %ymm0, 576(%rsp)
vmovupd 1696(%rsp), %ymm0
vfmadd231pd %ymm7, %ymm15, %ymm0
vmovupd %ymm0, 1696(%rsp)
vmovupd 736(%rsp), %ymm0
vfmadd231pd %ymm12, %ymm15, %ymm0
vmovupd %ymm0, 736(%rsp)
vmovupd 896(%rsp), %ymm0
vmovupd 3808(%rsp), %ymm1
vmovupd 3680(%rsp), %ymm5
vfmadd231pd %ymm1, %ymm5, %ymm0
vmovupd %ymm0, 896(%rsp)
vmovupd 160(%rsp), %ymm0
vmovupd 3456(%rsp), %ymm6
vfmadd231pd %ymm1, %ymm6, %ymm0
vmovupd %ymm0, 160(%rsp)
vmovupd 1152(%rsp), %ymm0
vmovupd 3232(%rsp), %ymm7
vfmadd231pd %ymm1, %ymm7, %ymm0
vmovupd %ymm0, 1152(%rsp)
vmovupd 384(%rsp), %ymm0
vfmadd231pd %ymm1, %ymm2, %ymm0
vmovupd %ymm0, 384(%rsp)
vmovupd 1472(%rsp), %ymm0
vfmadd231pd %ymm5, %ymm9, %ymm0
vmovupd %ymm0, 1472(%rsp)
vmovupd 608(%rsp), %ymm0
vfmadd231pd %ymm6, %ymm9, %ymm0
vmovupd %ymm0, 608(%rsp)
vmovupd 1728(%rsp), %ymm0
vfmadd231pd %ymm7, %ymm9, %ymm0
vmovupd %ymm0, 1728(%rsp)
vmovupd -128(%rsp), %ymm0
vfmadd231pd %ymm2, %ymm9, %ymm0
vmovupd %ymm0, -128(%rsp)
vmovupd 928(%rsp), %ymm0
vmovupd 3840(%rsp), %ymm1
vmovupd 3712(%rsp), %ymm2
vfmadd231pd %ymm1, %ymm2, %ymm0
vmovupd %ymm0, 928(%rsp)
vmovupd 192(%rsp), %ymm0
vmovupd 3488(%rsp), %ymm5
vfmadd231pd %ymm1, %ymm5, %ymm0
vmovupd %ymm0, 192(%rsp)
vmovupd 1184(%rsp), %ymm0
vmovupd 3264(%rsp), %ymm6
vfmadd231pd %ymm1, %ymm6, %ymm0
vmovupd %ymm0, 1184(%rsp)
vmovupd 416(%rsp), %ymm0
vfmadd231pd %ymm1, %ymm4, %ymm0
vmovupd %ymm0, 416(%rsp)
vmovupd 1504(%rsp), %ymm0
vfmadd231pd %ymm2, %ymm14, %ymm0
vmovupd %ymm0, 1504(%rsp)
vmovupd 640(%rsp), %ymm0
vfmadd231pd %ymm5, %ymm14, %ymm0
vmovupd %ymm0, 640(%rsp)
vmovupd 1760(%rsp), %ymm0
vfmadd231pd %ymm6, %ymm14, %ymm0
vmovupd %ymm0, 1760(%rsp)
vmovupd -96(%rsp), %ymm0
vfmadd231pd %ymm4, %ymm14, %ymm0
vmovupd %ymm0, -96(%rsp)
vpaddq 1984(%rsp), %ymm8, %ymm0
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rdx,%ymm0,8), %ymm2
vmovupd 960(%rsp), %ymm0
vmovupd 3872(%rsp), %ymm1
vmovupd 3744(%rsp), %ymm4
vfmadd231pd %ymm1, %ymm4, %ymm0
vmovupd %ymm0, 960(%rsp)
vmovupd 224(%rsp), %ymm0
vmovupd 3520(%rsp), %ymm5
vfmadd231pd %ymm1, %ymm5, %ymm0
vmovupd %ymm0, 224(%rsp)
vmovupd 1216(%rsp), %ymm0
vmovupd 3296(%rsp), %ymm6
vfmadd231pd %ymm1, %ymm6, %ymm0
vmovupd %ymm0, 1216(%rsp)
vmovupd 448(%rsp), %ymm0
vfmadd231pd %ymm1, %ymm3, %ymm0
vmovupd %ymm0, 448(%rsp)
vmovupd 1536(%rsp), %ymm0
vfmadd231pd %ymm4, %ymm11, %ymm0
vmovupd %ymm0, 1536(%rsp)
vmovupd 672(%rsp), %ymm0
vfmadd231pd %ymm5, %ymm11, %ymm0
vmovupd %ymm0, 672(%rsp)
vmovupd 1792(%rsp), %ymm0
vfmadd231pd %ymm6, %ymm11, %ymm0
vmovupd %ymm0, 1792(%rsp)
vmovupd -64(%rsp), %ymm0
vfmadd231pd %ymm3, %ymm11, %ymm0
vmovupd %ymm0, -64(%rsp)
vmovupd 992(%rsp), %ymm0
vmovupd 3904(%rsp), %ymm1
vmovupd 3776(%rsp), %ymm3
vfmadd231pd %ymm1, %ymm3, %ymm0
vmovupd %ymm0, 992(%rsp)
vmovupd 256(%rsp), %ymm0
vmovupd 3584(%rsp), %ymm4
vfmadd231pd %ymm1, %ymm4, %ymm0
vmovupd %ymm0, 256(%rsp)
vmovupd 1248(%rsp), %ymm0
vmovupd 3328(%rsp), %ymm5
vfmadd231pd %ymm1, %ymm5, %ymm0
vmovupd %ymm0, 1248(%rsp)
vmovupd 1312(%rsp), %ymm0
vfmadd231pd %ymm1, %ymm2, %ymm0
vmovupd %ymm0, 1312(%rsp)
vmovupd 1568(%rsp), %ymm0
vfmadd231pd %ymm3, %ymm10, %ymm0
vmovupd %ymm0, 1568(%rsp)
vmovupd 704(%rsp), %ymm0
vfmadd231pd %ymm4, %ymm10, %ymm0
vmovupd %ymm0, 704(%rsp)
vmovupd 1824(%rsp), %ymm0
vfmadd231pd %ymm5, %ymm10, %ymm0
vmovupd %ymm0, 1824(%rsp)
vmovupd -32(%rsp), %ymm0
vfmadd231pd %ymm2, %ymm10, %ymm0
vmovupd %ymm0, -32(%rsp)
vpaddq 1952(%rsp), %ymm8, %ymm8
addq $-4, %rcx
jne .LBB0_5
循环很大,显然没有正确矢量化:有很多完全无用的指令,并且从内存加载似乎不是连续的(请参阅 参考资料vgatherqpd
)。Numba 不会生成好的代码,因为底层 JIT (LLVM-Lite) 无法有效地向量化代码。事实上,我发现 Clang 13.0 在一个简化示例中对类似的 C++ 代码进行了严重矢量化(GCC 和 ICC 在更复杂的代码上也失败了),而手写的 SIMD 实现效果要好得多。它看起来像是优化器的错误或至少错过了优化。这就是 Numba 代码比最优代码慢得多的原因。话虽这么说,这个实现非常有效地利用了缓存,并且是适当的多线程。
我还发现,Linux 上的 BLAS 代码比我机器上的 Windows 更快(默认包来自 PIP 和版本 1.20.3 的相同 Numpy)。因此,两者之间的差距更近method2
,method3
但后者仍然明显更快。