0

这是这个问题的后续:

如何求和的平方和的和?

我在哪里寻求帮助以使用 einsum(以实现极大的速度提高)并得到了很好的答案。

我也得到了使用的建议numba。我已经尝试过两者,似乎在某个点之后速度增加numba要好得多。

那么如何在不遇到内存问题的情况下加快速度呢?

4

2 回答 2

2

下面的解决方案介绍了 3 种不同的方法来做简单的总和和 4 种不同的方法来做平方和。

总和 3 种方法 - for 循环、JIT for 循环、einsum(没有遇到内存问题)

和的平方和 4 种方法 - for 循环、JIT for 循环、扩展 einsum、中间 einsum

这里前三个不会遇到内存问题,而 for 循环和扩展的 einsum 会遇到速度问题。这使得 JIT 解决方案看起来是最好的。

import numpy as np
import time
from numba import jit

def fun1(Fu, Fv, Fx, Fy, P, B):
    Nu = Fu.shape[0]
    Nv = Fv.shape[0]
    Nx = Fx.shape[0]
    Ny = Fy.shape[0]
    Nk = Fu.shape[1]
    Nl = Fv.shape[1]
    I1 = np.zeros([Nu, Nv])
    for iu in range(Nu):
        for iv in range(Nv):
            for ix in range(Nx):
                for iy in range(Ny):
                    S = 0.
                    for ik in range(Nk):
                        for il in range(Nl):
                            S += Fu[iu,ik]*Fv[iv,il]*Fx[ix,ik]*Fy[iy,il]*P[ix,iy]*B[ik,il]
                    I1[iu, iv] += S
    return I1

def fun2(Fu, Fv, Fx, Fy, P, B):
    Nu = Fu.shape[0]
    Nv = Fv.shape[0]
    Nx = Fx.shape[0]
    Ny = Fy.shape[0]
    Nk = Fu.shape[1]
    Nl = Fv.shape[1]
    I2 = np.zeros([Nu, Nv])
    for iu in range(Nu):
        for iv in range(Nv):
            for ix in range(Nx):
                for iy in range(Ny):
                    S = 0.
                    for ik in range(Nk):
                        for il in range(Nl):
                            S += Fu[iu,ik]*Fv[iv,il]*Fx[ix,ik]*Fy[iy,il]*P[ix,iy]*B[ik,il]
                    I2[iu, iv] += S**2.
    return I2

if __name__ == '__main__':

    Nx = 30
    Ny = 40
    Nk = 50
    Nl = 60
    Nu = 70
    Nv = 8
    Fx = np.random.rand(Nx, Nk)
    Fy = np.random.rand(Ny, Nl)
    Fu = np.random.rand(Nu, Nk)
    Fv = np.random.rand(Nv, Nl)
    P = np.random.rand(Nx, Ny)
    B = np.random.rand(Nk, Nl)
    fjit1 = jit(fun1)
    fjit2 = jit(fun2)

    # For loop - becomes too slow so commented out
    # t = time.time()
    # I1 = fun1(Fu, Fv, Fx, Fy, P, B)
    # print 'fun1    :', time.time() - t

    # JIT compiled for loop - After a certain point beats einsum
    t = time.time()
    I1jit = fjit1(Fu, Fv, Fx, Fy, P, B)
    print 'jit1    :', time.time() - t

    # einsum great solution when no squaring is needed
    t = time.time()
    I1_ = np.einsum('uk, vl, xk, yl, xy, kl->uv', Fu, Fv, Fx, Fy, P, B)
    print '1 einsum:', time.time() - t

    # For loop - becomes too slow so commented out
    # t = time.time()
    # I2 = fun2(Fu, Fv, Fx, Fy, P, B)
    # print 'fun2    :', time.time() - t

    # JIT compiled for loop - After a certain point beats einsum
    t = time.time()
    I2jit = fjit2(Fu, Fv, Fx, Fy, P, B)
    print 'jit2    :', time.time() - t

    # Expanded einsum - As the size increases becomes very very slow
    # t = time.time()
    # I2_ = np.einsum('uk,vl,xk,yl,um,vn,xm,yn,kl,mn,xy->uv', Fu,Fv,Fx,Fy,Fu,Fv,Fx,Fy,B,B,P**2)
    # print '2 einsum:', time.time() - t

    # Intermediate einsum - As the sizes increase memory can become an issue
    t = time.time()
    temp = np.einsum('uk, vl, xk, yl, xy, kl->uvxy', Fu, Fv, Fx, Fy, P, B)
    I2__ = np.einsum('uvxy->uv', np.square(temp))
    print '2 einsum:', time.time() - t

    # print 'I1 == I1_   :', np.allclose(I1, I1_)
    print 'I1_ == Ijit1_:', np.allclose(I1_, I1jit)
    # print 'I2 == I2_   :', np.allclose(I2, I2_)
    print 'I2_ == Ijit2_:', np.allclose(I2__, I2jit)

评论:请随时编辑/改进此答案。如果有人对制作这种平行有任何建议,那就太好了。

于 2014-12-11T19:46:07.980 回答
1

您可以先对一个索引求和,然后再继续乘法。我也尝试过在最后的乘法和减法运算中加入 numexpr 的版本,但它似乎没有太大帮助。

def fun3(Fu, Fv, Fx, Fy, P, B):
    P = P[None, None, ...]
    Fu = Fu[:, None, None, None, :]
    Fx = Fx[None, None, :, None, :]
    Fv = Fv[:, None, None, :]
    Fy = Fy[None, :, None, :]
    B = B[None, None, ...]
    return np.sum((P*np.sum(Fu*Fx*np.sum(Fv*Fy*B, axis=-1)[None, :, None, :, :], axis=-1))**2, axis=(2, 3))

在我的电脑上速度要快得多:

jit2:7.06 秒

乐趣 3:0.144 秒

编辑:小改进 - 先乘然后平方。

Edit2:利用每个最擅长的(numexpr - 乘法,numpy - 点/张量点,求和),您仍然可以将 fun3 提高 20 倍以上。

def fun4(Fu, Fv, Fx, Fy, P, B):
    P = P[None, None, ...]
    Fu = Fu[:, None, :]
    Fx = Fx[None, ...]
    Fy = Fy[:, None, :]
    B = B[None, ...]
    s = ne.evaluate('Fu*Fx')
    r = np.tensordot(Fv, ne.evaluate('Fy*B'), axes=(1, 2))
    I = np.tensordot(s, r, axes=(2, 2)).swapaxes(1, 2)
    r = ne.evaluate('(P*I)**2')
    r = np.sum(r, axis=(2, 3))
    return r

fun4: 0.007 秒

此外,由于 fun8 不再那么耗费内存(由于智能张量点),您可以将更大的数组相乘并看到它使用多个内核。

于 2014-12-17T09:44:58.293 回答