1

假设我有两个数组:

import numpy as np
a = np.random.randn(32, 6, 6, 20, 64, 3, 3)
b = np.random.randn(20, 128, 64, 3, 3)

并希望对最后 3 个轴求和,并保留共享轴。输出维度应该是(32,6,6,20,128)。注意这里带有 20 的轴在a和中是共享的b。让我们将此轴称为“组”轴。

我有两种方法来完成这项任务:
第一种方法很简单einsum

def method1(a, b):
    return np.einsum('NHWgihw, goihw -> NHWgo', a, b, optimize=True)  # output shape:(32,6,6,20,128)

在第二种方法中,我遍历组维度并使用einsum/tensordot计算每个组维度的结果,然后将结果堆叠:

def method2(a, b):
    result = []
    for g in range(b.shape[0]): # loop through each group dimension
        # result.append(np.tensordot(a[..., g, :, :, :], b[g, ...], axes=((-3,-2,-1),(-3,-2,-1))))
        result.append(np.einsum('NHWihw, oihw -> NHWo', a[..., g, :, :, :], b[g, ...], optimize=True))  # output shape:(32,6,6,128)
    return np.stack(result, axis=-2)  # output shape:(32,6,6,20,128)

这是我的 jupyter notebook 中这两种方法的时间安排: 我们可以看到带有循环的第二种方法比第一种方法快。
在此处输入图像描述

我的问题是:

  1. 为什么method1慢得多?它不会计算更多的东西。
  2. 有没有更有效的方法而不使用循环?(我有点不愿意使用循环,因为它们在 python 中很慢)

谢谢你的帮助!

4

1 回答 1

4

正如@Murali 在评论中指出的那样,method1效率不是很高,因为它没有成功使用BLAS调用,而不是使用method2哪个调用。事实上,由于np.einsumOpenBLAS (Numpy 在大多数机器上使用),method1它按顺序计算结果,同时method2大部分并行运行。话虽如此,这method2是次优的,因为它没有完全使用可用的内核(部分计算是按顺序完成的)并且似乎没有有效地使用缓存。在我的 6 核机器上,它几乎不使用所有内核的 50%。


更快的实施

加快此计算的一种解决方案是为此编写高度优化的Numba并行代码。

首先,半天真的实现是使用许多 for 循环来计算 Einstein 求和并重塑输入/输出数组,以便 Numba 可以更好地优化代码(例如展开、使用SIMD 指令)。结果如下:

@nb.njit('float64[:,:,:,:,::1](float64[:,:,:,:,:,:,::1], float64[:,:,:,:,::1])')
def compute(a, b):
    sN, sH, sW, sg, si, sh, sw = a.shape
    so = b.shape[1]
    assert b.shape == (sg, so, si, sh, sw)

    ra = a.reshape(sN*sH*sW, sg, si*sh*sw)
    rb = b.reshape(sg, so, si*sh*sw)
    out = np.empty((sN*sH*sW, sg, so), dtype=np.float64)

    for NHW in range(sN*sH*sW):
        for g in range(sg):
            for o in range(so):
                s = 0.0

                # Reduction
                for ihw in range(si*sh*sw):
                    s += ra[NHW, g, ihw] * rb[g, o, ihw]

                out[NHW, g, o] = s

    return out.reshape((sN, sH, sW, sg, so))

请注意,假设输入数组是连续的。如果不是这种情况,请考虑执行复制(与计算相比便宜)。

虽然上面的代码有效,但它远非高效。以下是一些可以执行的改进:

  • 并行运行最外层NHW循环;
  • 使用 Numba 标志fastmath=True。如果输入数据包含特殊值,如 NaN 或 +inf/-inf,则此标志是不安全的。但是,此标志有助于编译器使用 SIMD 指令生成更快的代码(否则这是不可能的,因为 IEEE-754 浮点运算不是关联的);
  • 交换NHW基于 - 的循环和g基于 - 的循环会导致更好的性能,因为它提高了缓存局部性(rb更有可能适合主流 CPU 的最后一级缓存,否则它可能会从 RAM 中获取);
  • 利用寄存器阻塞,使处理器的更好的 SIMD 计算单元饱和,减少内存层次的压力;
  • 通过拆分基于 - 的循环来使用平铺,因此几乎可以从较低级别的缓存(例如 L1 或 L2)中完全读取。orb

除了最后一项之外,所有这些改进都在以下代码中实现:

@nb.njit('float64[:,:,:,:,::1](float64[:,:,:,:,:,:,::1], float64[:,:,:,:,::1])', parallel=True, fastmath=True)
def method3(a, b):
    sN, sH, sW, sg, si, sh, sw = a.shape
    so = b.shape[1]
    assert b.shape == (sg, so, si, sh, sw)

    ra = a.reshape(sN*sH*sW, sg, si*sh*sw)
    rb = b.reshape(sg, so, si*sh*sw)
    out = np.zeros((sN*sH*sW, sg, so), dtype=np.float64)

    for g in range(sg):
        for k in nb.prange((sN*sH*sW)//2):
            NHW = k*2
            so_vect_max = (so // 4) * 4

            for o in range(0, so_vect_max, 4):
                s00 = s01 = s02 = s03 = s10 = s11 = s12 = s13 = 0.0

                # Useful since Numba does not optimize well the following loop otherwise
                ra_row0 = ra[NHW+0, g, :]
                ra_row1 = ra[NHW+1, g, :]
                rb_row0 = rb[g, o+0, :]
                rb_row1 = rb[g, o+1, :]
                rb_row2 = rb[g, o+2, :]
                rb_row3 = rb[g, o+3, :]

                # Highly-optimized reduction using register blocking
                for ihw in range(si*sh*sw):
                    ra_0 = ra_row0[ihw]
                    ra_1 = ra_row1[ihw]
                    rb_0 = rb_row0[ihw]
                    rb_1 = rb_row1[ihw]
                    rb_2 = rb_row2[ihw]
                    rb_3 = rb_row3[ihw]
                    s00 += ra_0 * rb_0; s01 += ra_0 * rb_1
                    s02 += ra_0 * rb_2; s03 += ra_0 * rb_3
                    s10 += ra_1 * rb_0; s11 += ra_1 * rb_1
                    s12 += ra_1 * rb_2; s13 += ra_1 * rb_3

                out[NHW+0, g, o+0] = s00; out[NHW+0, g, o+1] = s01
                out[NHW+0, g, o+2] = s02; out[NHW+0, g, o+3] = s03
                out[NHW+1, g, o+0] = s10; out[NHW+1, g, o+1] = s11
                out[NHW+1, g, o+2] = s12; out[NHW+1, g, o+3] = s13

            # Remaining part for `o`
            for o in range(so_vect_max, so):
                for ihw in range(si*sh*sw):
                    out[NHW, g, o] += ra[NHW, g, ihw] * rb[g, o, ihw]
                    out[NHW+1, g, o] += ra[NHW+1, g, ihw] * rb[g, o, ihw]

        # Remaining part for `k`
        if (sN*sH*sW) % 2 == 1:
            k = sN*sH*sW - 1
            for o in range(so):
                for ihw in range(si*sh*sw):
                    out[k, g, o] += ra[k, g, ihw] * rb[g, o, ihw]


    return out.reshape((sN, sH, sW, sg, so))

这段代码更加复杂和丑陋,但也更加高效。我没有实现平铺优化,因为它会使代码的可读性更差。但是,它应该会在多核处理器(尤其是具有小型 L2/L3 缓存的处理器)上显着加快代码速度。


性能结果

以下是我的 i5-9600KF 6 核处理器的性能结果:

method1:              816 ms
method2:              104 ms
method3:               40 ms
Theoretical optimal:    9 ms   (optimistic lower bound)

该代码比method2. 由于最佳时间比 . 好 4 倍左右,因此还有改进的余地method3

Numba 不生成快速代码的主要原因是底层 JIT 未能有效地矢量化循环。实施平铺策略应该会稍微提高执行时间,非常接近最佳时间。平铺策略对于更大的阵列至关重要。如果so更大,则尤其如此。

如果您想要更快的实现,您当然需要直接使用 SIMD 内在函数(遗憾的是不可移植)或 SIMD 库(例如 XSIMD)编写 C/C++ 本机代码。

如果您想要更快的实现,那么您需要使用更快的硬件(具有更多内核)或更专用的硬件。基于服务器的 GPU(即不是个人计算机)不应该能够加速很多这样的计算,因为您的输入很小,显然受计算限制并且大量使用 FMA 浮点运算。第一个开始是尝试cupy.einsum


底层:底层分析

为了理解为什么method1没有更快,我检查了执行的代码。这是主循环:

1a0:┌─→; Part of the reduction (see below)
    │  movapd     xmm0,XMMWORD PTR [rdi-0x1000]
    │  
    │  ; Decrement the number of loop cycle
    │  sub        r9,0x8 
    │  
    │  ; Prefetch items so to reduce the impact 
    │  ; of the latency of reading from the RAM.
    │  prefetcht0 BYTE PTR [r8]
    │  prefetcht0 BYTE PTR [rdi]
    │  
    │  ; Part of the reduction (see below)
    │  mulpd      xmm0,XMMWORD PTR [r8-0x1000]
    │  
    │  ; Increment iterator for the two arrays
    │  add        rdi,0x40 
    │  add        r8,0x40 
    │  
    │  ; Main computational part: 
    │  ; reduction using add+mul SSE2 instructions
    │  addpd      xmm1,xmm0                     <--- Slow
    │  movapd     xmm0,XMMWORD PTR [rdi-0x1030]
    │  mulpd      xmm0,XMMWORD PTR [r8-0x1030]
    │  addpd      xmm1,xmm0                     <--- Slow
    │  movapd     xmm0,XMMWORD PTR [rdi-0x1020]
    │  mulpd      xmm0,XMMWORD PTR [r8-0x1020]
    │  addpd      xmm0,xmm1                     <--- Slow
    │  movapd     xmm1,XMMWORD PTR [rdi-0x1010]
    │  mulpd      xmm1,XMMWORD PTR [r8-0x1010]
    │  addpd      xmm1,xmm0                     <--- Slow
    │  
    │  ; Is the loop over? 
    │  ; If not, jump to the beginning of the loop.
    ├──cmp        r9,0x7 
    └──jg         1a0

事实证明,Numpy 使用 SSE2 指令集(在所有 x86-64 处理器上都可用)。然而,我的机器,就像几乎所有相对较新的处理器一样,支持 AVX 指令集,每条指令一次可以计算两倍以上的项目。我的机器还支持在这种情况下快两倍的熔丝倍增指令 (FMA)。此外,循环清楚地受到将addpd结果累积在几乎相同的寄存器中的限制。处理器无法有效地执行它们,因为addpd需要几个延迟周期,并且在现代 x86-64 处理器上最多可以同时执行两个(这在此处是不可能的,因为一次只有一个指令可以执行累积xmm1)。

这是method2dgemmOpenBLAS的调用)的主要计算部分的执行代码:

6a40:┌─→vbroadcastsd ymm0,QWORD PTR [rsi-0x60]
     │  vbroadcastsd ymm1,QWORD PTR [rsi-0x58]
     │  vbroadcastsd ymm2,QWORD PTR [rsi-0x50]
     │  vbroadcastsd ymm3,QWORD PTR [rsi-0x48]
     │  vfmadd231pd  ymm4,ymm0,YMMWORD PTR [rdi-0x80]
     │  vfmadd231pd  ymm5,ymm1,YMMWORD PTR [rdi-0x60]
     │  vbroadcastsd ymm0,QWORD PTR [rsi-0x40]
     │  vbroadcastsd ymm1,QWORD PTR [rsi-0x38]
     │  vfmadd231pd  ymm6,ymm2,YMMWORD PTR [rdi-0x40]
     │  vfmadd231pd  ymm7,ymm3,YMMWORD PTR [rdi-0x20]
     │  vbroadcastsd ymm2,QWORD PTR [rsi-0x30]
     │  vbroadcastsd ymm3,QWORD PTR [rsi-0x28]
     │  vfmadd231pd  ymm4,ymm0,YMMWORD PTR [rdi]
     │  vfmadd231pd  ymm5,ymm1,YMMWORD PTR [rdi+0x20]
     │  vfmadd231pd  ymm6,ymm2,YMMWORD PTR [rdi+0x40]
     │  vfmadd231pd  ymm7,ymm3,YMMWORD PTR [rdi+0x60]
     │  add          rsi,0x40
     │  add          rdi,0x100
     ├──dec          rax
     └──jne          6a40

这个循环更加优化:它利用了 AVX 指令集以及 FMA 指令集(即vfmadd231pd指令)。此外,循环更好地展开,并且没有像 Numpy 代码中那样的延迟/依赖问题。然而,虽然这个循环非常高效,但由于在 Numpy 中进行了一些顺序检查和在 OpenBLAS 中执行顺序复制,内核并没有得到有效使用。此外,我不确定在这种情况下循环是否有效地使用了缓存,因为在我的机器上的 RAM 中执行了大量的读/写操作。实际上,由于许多缓存未命中,RAM 吞吐量约为 15 GiB/s(超过 35~40 GiB/s),而吞吐量method3为 6 GiB/s(因此在缓存中完成了更多工作),执行速度明显更快。

这是主要计算部分的执行代码method3

.LBB0_5:
    vorpd   2880(%rsp), %ymm8, %ymm0
    vpcmpeqd    %ymm1, %ymm1, %ymm1
    vgatherqpd  %ymm1, (%rsi,%ymm0,8), %ymm2
    vmovupd %ymm2, 3040(%rsp)
    vorpd   2848(%rsp), %ymm8, %ymm1
    vpcmpeqd    %ymm2, %ymm2, %ymm2
    vgatherqpd  %ymm2, (%rsi,%ymm1,8), %ymm3
    vmovupd %ymm3, 3104(%rsp)
    vorpd   2912(%rsp), %ymm8, %ymm2
    vpcmpeqd    %ymm3, %ymm3, %ymm3
    vgatherqpd  %ymm3, (%rsi,%ymm2,8), %ymm4
    vmovupd %ymm4, 3136(%rsp)
    vorpd   2816(%rsp), %ymm8, %ymm3
    vpcmpeqd    %ymm4, %ymm4, %ymm4
    vgatherqpd  %ymm4, (%rsi,%ymm3,8), %ymm5
    vmovupd %ymm5, 3808(%rsp)
    vorpd   2784(%rsp), %ymm8, %ymm9
    vpcmpeqd    %ymm4, %ymm4, %ymm4
    vgatherqpd  %ymm4, (%rsi,%ymm9,8), %ymm5
    vmovupd %ymm5, 3840(%rsp)
    vorpd   2752(%rsp), %ymm8, %ymm10
    vpcmpeqd    %ymm4, %ymm4, %ymm4
    vgatherqpd  %ymm4, (%rsi,%ymm10,8), %ymm5
    vmovupd %ymm5, 3872(%rsp)
    vpaddq  2944(%rsp), %ymm8, %ymm4
    vorpd   2720(%rsp), %ymm8, %ymm11
    vpcmpeqd    %ymm13, %ymm13, %ymm13
    vgatherqpd  %ymm13, (%rsi,%ymm11,8), %ymm5
    vmovupd %ymm5, 3904(%rsp)
    vpcmpeqd    %ymm13, %ymm13, %ymm13
    vgatherqpd  %ymm13, (%rdx,%ymm0,8), %ymm5
    vmovupd %ymm5, 3552(%rsp)
    vpcmpeqd    %ymm0, %ymm0, %ymm0
    vgatherqpd  %ymm0, (%rdx,%ymm1,8), %ymm5
    vmovupd %ymm5, 3616(%rsp)
    vpcmpeqd    %ymm0, %ymm0, %ymm0
    vgatherqpd  %ymm0, (%rdx,%ymm2,8), %ymm1
    vmovupd %ymm1, 3648(%rsp)
    vpcmpeqd    %ymm0, %ymm0, %ymm0
    vgatherqpd  %ymm0, (%rdx,%ymm3,8), %ymm1
    vmovupd %ymm1, 3680(%rsp)
    vpcmpeqd    %ymm0, %ymm0, %ymm0
    vgatherqpd  %ymm0, (%rdx,%ymm9,8), %ymm1
    vmovupd %ymm1, 3712(%rsp)
    vpcmpeqd    %ymm0, %ymm0, %ymm0
    vgatherqpd  %ymm0, (%rdx,%ymm10,8), %ymm1
    vmovupd %ymm1, 3744(%rsp)
    vpcmpeqd    %ymm0, %ymm0, %ymm0
    vgatherqpd  %ymm0, (%rdx,%ymm11,8), %ymm1
    vmovupd %ymm1, 3776(%rsp)
    vpcmpeqd    %ymm0, %ymm0, %ymm0
    vgatherqpd  %ymm0, (%rsi,%ymm4,8), %ymm6
    vpcmpeqd    %ymm0, %ymm0, %ymm0
    vgatherqpd  %ymm0, (%rdx,%ymm4,8), %ymm3
    vpaddq  2688(%rsp), %ymm8, %ymm0
    vpcmpeqd    %ymm1, %ymm1, %ymm1
    vgatherqpd  %ymm1, (%rsi,%ymm0,8), %ymm7
    vpcmpeqd    %ymm1, %ymm1, %ymm1
    vgatherqpd  %ymm1, (%rdx,%ymm0,8), %ymm4
    vmovupd %ymm4, 3360(%rsp)
    vpaddq  2656(%rsp), %ymm8, %ymm0
    vpcmpeqd    %ymm1, %ymm1, %ymm1
    vgatherqpd  %ymm1, (%rsi,%ymm0,8), %ymm13
    vpcmpeqd    %ymm1, %ymm1, %ymm1
    vgatherqpd  %ymm1, (%rdx,%ymm0,8), %ymm4
    vmovupd %ymm4, 3392(%rsp)
    vpaddq  2624(%rsp), %ymm8, %ymm0
    vpcmpeqd    %ymm1, %ymm1, %ymm1
    vgatherqpd  %ymm1, (%rsi,%ymm0,8), %ymm15
    vpcmpeqd    %ymm1, %ymm1, %ymm1
    vgatherqpd  %ymm1, (%rdx,%ymm0,8), %ymm4
    vmovupd %ymm4, 3424(%rsp)
    vpaddq  2592(%rsp), %ymm8, %ymm0
    vpcmpeqd    %ymm1, %ymm1, %ymm1
    vgatherqpd  %ymm1, (%rsi,%ymm0,8), %ymm9
    vpcmpeqd    %ymm1, %ymm1, %ymm1
    vgatherqpd  %ymm1, (%rdx,%ymm0,8), %ymm4
    vmovupd %ymm4, 3456(%rsp)
    vpaddq  2560(%rsp), %ymm8, %ymm0
    vpcmpeqd    %ymm1, %ymm1, %ymm1
    vgatherqpd  %ymm1, (%rsi,%ymm0,8), %ymm14
    vpcmpeqd    %ymm1, %ymm1, %ymm1
    vgatherqpd  %ymm1, (%rdx,%ymm0,8), %ymm4
    vmovupd %ymm4, 3488(%rsp)
    vpaddq  2528(%rsp), %ymm8, %ymm0
    vpcmpeqd    %ymm1, %ymm1, %ymm1
    vgatherqpd  %ymm1, (%rsi,%ymm0,8), %ymm11
    vpcmpeqd    %ymm1, %ymm1, %ymm1
    vgatherqpd  %ymm1, (%rdx,%ymm0,8), %ymm4
    vmovupd %ymm4, 3520(%rsp)
    vpaddq  2496(%rsp), %ymm8, %ymm0
    vpcmpeqd    %ymm1, %ymm1, %ymm1
    vgatherqpd  %ymm1, (%rsi,%ymm0,8), %ymm10
    vpcmpeqd    %ymm1, %ymm1, %ymm1
    vgatherqpd  %ymm1, (%rdx,%ymm0,8), %ymm4
    vmovupd %ymm4, 3584(%rsp)
    vpaddq  2464(%rsp), %ymm8, %ymm0
    vpcmpeqd    %ymm1, %ymm1, %ymm1
    vgatherqpd  %ymm1, (%rdx,%ymm0,8), %ymm2
    vpaddq  2432(%rsp), %ymm8, %ymm0
    vpcmpeqd    %ymm1, %ymm1, %ymm1
    vgatherqpd  %ymm1, (%rdx,%ymm0,8), %ymm12
    vpaddq  2400(%rsp), %ymm8, %ymm0
    vpcmpeqd    %ymm1, %ymm1, %ymm1
    vgatherqpd  %ymm1, (%rdx,%ymm0,8), %ymm4
    vmovupd %ymm4, 3168(%rsp)
    vpaddq  2368(%rsp), %ymm8, %ymm0
    vpcmpeqd    %ymm1, %ymm1, %ymm1
    vgatherqpd  %ymm1, (%rdx,%ymm0,8), %ymm4
    vmovupd %ymm4, 3200(%rsp)
    vpaddq  2336(%rsp), %ymm8, %ymm0
    vpcmpeqd    %ymm1, %ymm1, %ymm1
    vgatherqpd  %ymm1, (%rdx,%ymm0,8), %ymm4
    vmovupd %ymm4, 3232(%rsp)
    vpaddq  2304(%rsp), %ymm8, %ymm0
    vpcmpeqd    %ymm1, %ymm1, %ymm1
    vgatherqpd  %ymm1, (%rdx,%ymm0,8), %ymm4
    vmovupd %ymm4, 3264(%rsp)
    vpaddq  2272(%rsp), %ymm8, %ymm0
    vpcmpeqd    %ymm1, %ymm1, %ymm1
    vgatherqpd  %ymm1, (%rdx,%ymm0,8), %ymm4
    vmovupd %ymm4, 3296(%rsp)
    vpaddq  2240(%rsp), %ymm8, %ymm0
    vpcmpeqd    %ymm1, %ymm1, %ymm1
    vgatherqpd  %ymm1, (%rdx,%ymm0,8), %ymm4
    vmovupd %ymm4, 3328(%rsp)
    vpaddq  2208(%rsp), %ymm8, %ymm0
    vpcmpeqd    %ymm1, %ymm1, %ymm1
    vgatherqpd  %ymm1, (%rdx,%ymm0,8), %ymm4
    vpaddq  2176(%rsp), %ymm8, %ymm0
    vpcmpeqd    %ymm1, %ymm1, %ymm1
    vgatherqpd  %ymm1, (%rdx,%ymm0,8), %ymm5
    vmovupd %ymm5, 2976(%rsp)
    vpaddq  2144(%rsp), %ymm8, %ymm0
    vpcmpeqd    %ymm1, %ymm1, %ymm1
    vgatherqpd  %ymm1, (%rdx,%ymm0,8), %ymm5
    vmovupd %ymm5, 3008(%rsp)
    vpaddq  2112(%rsp), %ymm8, %ymm0
    vpcmpeqd    %ymm1, %ymm1, %ymm1
    vgatherqpd  %ymm1, (%rdx,%ymm0,8), %ymm5
    vmovupd %ymm5, 3072(%rsp)
    vpcmpeqd    %ymm1, %ymm1, %ymm1
    vgatherqpd  %ymm1, (%rsi,%ymm8,8), %ymm0
    vpcmpeqd    %ymm5, %ymm5, %ymm5
    vgatherqpd  %ymm5, (%rdx,%ymm8,8), %ymm1
    vmovupd 768(%rsp), %ymm5
    vfmadd231pd %ymm0, %ymm1, %ymm5
    vmovupd %ymm5, 768(%rsp)
    vmovupd 32(%rsp), %ymm5
    vfmadd231pd %ymm0, %ymm3, %ymm5
    vmovupd %ymm5, 32(%rsp)
    vmovupd 1024(%rsp), %ymm5
    vfmadd231pd %ymm0, %ymm2, %ymm5
    vmovupd %ymm5, 1024(%rsp)
    vmovupd 1280(%rsp), %ymm5
    vfmadd231pd %ymm0, %ymm4, %ymm5
    vmovupd %ymm5, 1280(%rsp)
    vmovupd 1344(%rsp), %ymm0
    vfmadd231pd %ymm1, %ymm6, %ymm0
    vmovupd %ymm0, 1344(%rsp)
    vmovupd 480(%rsp), %ymm0
    vfmadd231pd %ymm3, %ymm6, %ymm0
    vmovupd %ymm0, 480(%rsp)
    vmovupd 1600(%rsp), %ymm0
    vfmadd231pd %ymm2, %ymm6, %ymm0
    vmovupd %ymm0, 1600(%rsp)
    vmovupd 1856(%rsp), %ymm0
    vfmadd231pd %ymm4, %ymm6, %ymm0
    vmovupd %ymm0, 1856(%rsp)
    vpaddq  2080(%rsp), %ymm8, %ymm0
    vpcmpeqd    %ymm1, %ymm1, %ymm1
    vgatherqpd  %ymm1, (%rdx,%ymm0,8), %ymm2
    vpaddq  2048(%rsp), %ymm8, %ymm0
    vpcmpeqd    %ymm1, %ymm1, %ymm1
    vgatherqpd  %ymm1, (%rdx,%ymm0,8), %ymm4
    vmovupd 800(%rsp), %ymm0
    vmovupd 3552(%rsp), %ymm1
    vmovupd 3040(%rsp), %ymm3
    vfmadd231pd %ymm3, %ymm1, %ymm0
    vmovupd %ymm0, 800(%rsp)
    vmovupd 64(%rsp), %ymm0
    vmovupd 3360(%rsp), %ymm5
    vfmadd231pd %ymm3, %ymm5, %ymm0
    vmovupd %ymm0, 64(%rsp)
    vmovupd 1056(%rsp), %ymm0
    vfmadd231pd %ymm3, %ymm12, %ymm0
    vmovupd %ymm0, 1056(%rsp)
    vmovupd 288(%rsp), %ymm0
    vmovupd 2976(%rsp), %ymm6
    vfmadd231pd %ymm3, %ymm6, %ymm0
    vmovupd %ymm0, 288(%rsp)
    vmovupd 1376(%rsp), %ymm0
    vfmadd231pd %ymm1, %ymm7, %ymm0
    vmovupd %ymm0, 1376(%rsp)
    vmovupd 512(%rsp), %ymm0
    vfmadd231pd %ymm5, %ymm7, %ymm0
    vmovupd %ymm0, 512(%rsp)
    vmovupd 1632(%rsp), %ymm0
    vfmadd231pd %ymm12, %ymm7, %ymm0
    vmovupd %ymm0, 1632(%rsp)
    vmovupd 1888(%rsp), %ymm0
    vfmadd231pd %ymm6, %ymm7, %ymm0
    vmovupd %ymm0, 1888(%rsp)
    vmovupd 832(%rsp), %ymm0
    vmovupd 3616(%rsp), %ymm1
    vmovupd 3104(%rsp), %ymm6
    vfmadd231pd %ymm6, %ymm1, %ymm0
    vmovupd %ymm0, 832(%rsp)
    vmovupd 96(%rsp), %ymm0
    vmovupd 3392(%rsp), %ymm3
    vfmadd231pd %ymm6, %ymm3, %ymm0
    vmovupd %ymm0, 96(%rsp)
    vmovupd 1088(%rsp), %ymm0
    vmovupd 3168(%rsp), %ymm5
    vfmadd231pd %ymm6, %ymm5, %ymm0
    vmovupd %ymm0, 1088(%rsp)
    vmovupd 320(%rsp), %ymm0
    vmovupd 3008(%rsp), %ymm7
    vfmadd231pd %ymm6, %ymm7, %ymm0
    vmovupd %ymm0, 320(%rsp)
    vmovupd 1408(%rsp), %ymm0
    vfmadd231pd %ymm1, %ymm13, %ymm0
    vmovupd %ymm0, 1408(%rsp)
    vmovupd 544(%rsp), %ymm0
    vfmadd231pd %ymm3, %ymm13, %ymm0
    vmovupd %ymm0, 544(%rsp)
    vmovupd 1664(%rsp), %ymm0
    vfmadd231pd %ymm5, %ymm13, %ymm0
    vmovupd %ymm0, 1664(%rsp)
    vmovupd 1920(%rsp), %ymm0
    vfmadd231pd %ymm7, %ymm13, %ymm0
    vmovupd %ymm0, 1920(%rsp)
    vpaddq  2016(%rsp), %ymm8, %ymm0
    vpcmpeqd    %ymm1, %ymm1, %ymm1
    vgatherqpd  %ymm1, (%rdx,%ymm0,8), %ymm3
    vmovupd 864(%rsp), %ymm0
    vmovupd 3648(%rsp), %ymm1
    vmovupd 3136(%rsp), %ymm6
    vfmadd231pd %ymm6, %ymm1, %ymm0
    vmovupd %ymm0, 864(%rsp)
    vmovupd 128(%rsp), %ymm0
    vmovupd 3424(%rsp), %ymm5
    vfmadd231pd %ymm6, %ymm5, %ymm0
    vmovupd %ymm0, 128(%rsp)
    vmovupd 1120(%rsp), %ymm0
    vmovupd 3200(%rsp), %ymm7
    vfmadd231pd %ymm6, %ymm7, %ymm0
    vmovupd %ymm0, 1120(%rsp)
    vmovupd 352(%rsp), %ymm0
    vmovupd 3072(%rsp), %ymm12
    vfmadd231pd %ymm6, %ymm12, %ymm0
    vmovupd %ymm0, 352(%rsp)
    vmovupd 1440(%rsp), %ymm0
    vfmadd231pd %ymm1, %ymm15, %ymm0
    vmovupd %ymm0, 1440(%rsp)
    vmovupd 576(%rsp), %ymm0
    vfmadd231pd %ymm5, %ymm15, %ymm0
    vmovupd %ymm0, 576(%rsp)
    vmovupd 1696(%rsp), %ymm0
    vfmadd231pd %ymm7, %ymm15, %ymm0
    vmovupd %ymm0, 1696(%rsp)
    vmovupd 736(%rsp), %ymm0
    vfmadd231pd %ymm12, %ymm15, %ymm0
    vmovupd %ymm0, 736(%rsp)
    vmovupd 896(%rsp), %ymm0
    vmovupd 3808(%rsp), %ymm1
    vmovupd 3680(%rsp), %ymm5
    vfmadd231pd %ymm1, %ymm5, %ymm0
    vmovupd %ymm0, 896(%rsp)
    vmovupd 160(%rsp), %ymm0
    vmovupd 3456(%rsp), %ymm6
    vfmadd231pd %ymm1, %ymm6, %ymm0
    vmovupd %ymm0, 160(%rsp)
    vmovupd 1152(%rsp), %ymm0
    vmovupd 3232(%rsp), %ymm7
    vfmadd231pd %ymm1, %ymm7, %ymm0
    vmovupd %ymm0, 1152(%rsp)
    vmovupd 384(%rsp), %ymm0
    vfmadd231pd %ymm1, %ymm2, %ymm0
    vmovupd %ymm0, 384(%rsp)
    vmovupd 1472(%rsp), %ymm0
    vfmadd231pd %ymm5, %ymm9, %ymm0
    vmovupd %ymm0, 1472(%rsp)
    vmovupd 608(%rsp), %ymm0
    vfmadd231pd %ymm6, %ymm9, %ymm0
    vmovupd %ymm0, 608(%rsp)
    vmovupd 1728(%rsp), %ymm0
    vfmadd231pd %ymm7, %ymm9, %ymm0
    vmovupd %ymm0, 1728(%rsp)
    vmovupd -128(%rsp), %ymm0
    vfmadd231pd %ymm2, %ymm9, %ymm0
    vmovupd %ymm0, -128(%rsp)
    vmovupd 928(%rsp), %ymm0
    vmovupd 3840(%rsp), %ymm1
    vmovupd 3712(%rsp), %ymm2
    vfmadd231pd %ymm1, %ymm2, %ymm0
    vmovupd %ymm0, 928(%rsp)
    vmovupd 192(%rsp), %ymm0
    vmovupd 3488(%rsp), %ymm5
    vfmadd231pd %ymm1, %ymm5, %ymm0
    vmovupd %ymm0, 192(%rsp)
    vmovupd 1184(%rsp), %ymm0
    vmovupd 3264(%rsp), %ymm6
    vfmadd231pd %ymm1, %ymm6, %ymm0
    vmovupd %ymm0, 1184(%rsp)
    vmovupd 416(%rsp), %ymm0
    vfmadd231pd %ymm1, %ymm4, %ymm0
    vmovupd %ymm0, 416(%rsp)
    vmovupd 1504(%rsp), %ymm0
    vfmadd231pd %ymm2, %ymm14, %ymm0
    vmovupd %ymm0, 1504(%rsp)
    vmovupd 640(%rsp), %ymm0
    vfmadd231pd %ymm5, %ymm14, %ymm0
    vmovupd %ymm0, 640(%rsp)
    vmovupd 1760(%rsp), %ymm0
    vfmadd231pd %ymm6, %ymm14, %ymm0
    vmovupd %ymm0, 1760(%rsp)
    vmovupd -96(%rsp), %ymm0
    vfmadd231pd %ymm4, %ymm14, %ymm0
    vmovupd %ymm0, -96(%rsp)
    vpaddq  1984(%rsp), %ymm8, %ymm0
    vpcmpeqd    %ymm1, %ymm1, %ymm1
    vgatherqpd  %ymm1, (%rdx,%ymm0,8), %ymm2
    vmovupd 960(%rsp), %ymm0
    vmovupd 3872(%rsp), %ymm1
    vmovupd 3744(%rsp), %ymm4
    vfmadd231pd %ymm1, %ymm4, %ymm0
    vmovupd %ymm0, 960(%rsp)
    vmovupd 224(%rsp), %ymm0
    vmovupd 3520(%rsp), %ymm5
    vfmadd231pd %ymm1, %ymm5, %ymm0
    vmovupd %ymm0, 224(%rsp)
    vmovupd 1216(%rsp), %ymm0
    vmovupd 3296(%rsp), %ymm6
    vfmadd231pd %ymm1, %ymm6, %ymm0
    vmovupd %ymm0, 1216(%rsp)
    vmovupd 448(%rsp), %ymm0
    vfmadd231pd %ymm1, %ymm3, %ymm0
    vmovupd %ymm0, 448(%rsp)
    vmovupd 1536(%rsp), %ymm0
    vfmadd231pd %ymm4, %ymm11, %ymm0
    vmovupd %ymm0, 1536(%rsp)
    vmovupd 672(%rsp), %ymm0
    vfmadd231pd %ymm5, %ymm11, %ymm0
    vmovupd %ymm0, 672(%rsp)
    vmovupd 1792(%rsp), %ymm0
    vfmadd231pd %ymm6, %ymm11, %ymm0
    vmovupd %ymm0, 1792(%rsp)
    vmovupd -64(%rsp), %ymm0
    vfmadd231pd %ymm3, %ymm11, %ymm0
    vmovupd %ymm0, -64(%rsp)
    vmovupd 992(%rsp), %ymm0
    vmovupd 3904(%rsp), %ymm1
    vmovupd 3776(%rsp), %ymm3
    vfmadd231pd %ymm1, %ymm3, %ymm0
    vmovupd %ymm0, 992(%rsp)
    vmovupd 256(%rsp), %ymm0
    vmovupd 3584(%rsp), %ymm4
    vfmadd231pd %ymm1, %ymm4, %ymm0
    vmovupd %ymm0, 256(%rsp)
    vmovupd 1248(%rsp), %ymm0
    vmovupd 3328(%rsp), %ymm5
    vfmadd231pd %ymm1, %ymm5, %ymm0
    vmovupd %ymm0, 1248(%rsp)
    vmovupd 1312(%rsp), %ymm0
    vfmadd231pd %ymm1, %ymm2, %ymm0
    vmovupd %ymm0, 1312(%rsp)
    vmovupd 1568(%rsp), %ymm0
    vfmadd231pd %ymm3, %ymm10, %ymm0
    vmovupd %ymm0, 1568(%rsp)
    vmovupd 704(%rsp), %ymm0
    vfmadd231pd %ymm4, %ymm10, %ymm0
    vmovupd %ymm0, 704(%rsp)
    vmovupd 1824(%rsp), %ymm0
    vfmadd231pd %ymm5, %ymm10, %ymm0
    vmovupd %ymm0, 1824(%rsp)
    vmovupd -32(%rsp), %ymm0
    vfmadd231pd %ymm2, %ymm10, %ymm0
    vmovupd %ymm0, -32(%rsp)
    vpaddq  1952(%rsp), %ymm8, %ymm8
    addq    $-4, %rcx
    jne .LBB0_5

循环很大,显然没有正确矢量化:有很多完全无用的指令,并且从内存加载似乎不是连续的(请参阅 参考资料vgatherqpd)。Numba 不会生成好的代码,因为底层 JIT (LLVM-Lite) 无法有效地向量化代码。事实上,我发现 Clang 13.0 在一个简化示例中对类似的 C++ 代码进行了严重矢量化(GCC 和 ICC 在更复杂的代码上也失败了),而手写的 SIMD 实现效果要好得多。它看起来像是优化器的错误或至少错过了优化。这就是 Numba 代码比最优代码慢得多的原因。话虽这么说,这个实现非常有效地利用了缓存,并且是适当的多线程。

我还发现,Linux 上的 BLAS 代码比我机器上的 Windows 更快(默认包来自 PIP 和版本 1.20.3 的相同 Numpy)。因此,两者之间的差距更近method2method3但后者仍然明显更快。

于 2022-01-30T01:41:16.007 回答