1

我想使用 weave.blitz 来提高以下 numpy 代码的性能:

def fastIteration(self):
    g = self.grid
    nx,ny = g.ux.shape

    uxold = g.old_ux
    ux = g.ux
    ux[0:,1:-1] = uxold[0:,1:-1] + ReI* (uxold[0:,2:] - 2*uxold[0:,1:-1] + uxold[0:,0:-2])

    g.setBC()
    g.old_ux = ux.copy()

在此代码中,g 是计算网格。其中包含两个不同的字段 ux 和 uxold。old 仅用于临时存储变量。在完整的代码中,大约 95% 的运行时间花费在 fastIteration 方法上,因此即使是简单的性能提升也会显着减少执行此代码所花费的时间。

numpy 方法的输出看起来像:

麻木的结果

由于这段代码是我的瓶颈,我想通过使用 weave blitz 来提高速度。这个方法看起来像:

def blitzIteration(self):
    ### does not work correct so far
    g = self.grid
    nx,ny = g.ux.shape

    uxold = g.old_ux
    ux = g.ux
    expr = "ux[0:,1:-1] = uxold[0:,1:-1] + ReI* (uxold[0:,2:] - 2*uxold[0:,1:-1] + uxold[0:,0:-2])"
    weave.blitz(expr, check_size=0)
    g.setBC()
    g.old_ux = ux.copy()

但是,这不会产生正确的输出: 闪电战代码的输出

4

1 回答 1

2

它看起来像一个错误weave.blitz(复制,归档和修复。那里有关于实际错误的更多信息)。

我觉得写一个完整的切片0:而不是更短的代码很奇怪,:所以我替换了所有这些切片,瞧,它奏效了。

我真的不知道错误在哪里,但expr_code生成的weave.blitz略有不同:

  • 使用时0:

    ipdb> expr_code
    'ux_blitz_buggy(blitz::Range(0,_end),blitz::Range(1,Nux_blitz_buggy(1)-1-1))=uxold(blitz::Range(0,_end),blitz::Range(1,Nuxold(1)-1-1))+ReI*(uxold(blitz::Range(0,_end),blitz::Range(2,_end))-2*uxold(blitz::Range(0,_end),blitz::Range(1,Nuxold(1)-1-1))+uxold(blitz::Range(0,_end),blitz::Range(0,Nuxold(1)-2-1)));\n'
    
  • 使用时:

    ipdb> expr_code
    'ux_blitz_not_buggy(_all,blitz::Range(1,Nux_blitz_not_buggy(1)-1-1))=uxold(_all,blitz::Range(1,Nuxold(1)-1-1))+ReI*(uxold(_all,blitz::Range(2,_end))-2*uxold(_all,blitz::Range(1,Nuxold(1)-1-1))+uxold(_all,blitz::Range(0,Nuxold(1)-2-1)));\n'
    

所以,blitz::Range(0,_end)变成_all,他们以不同的方式行事。

为方便起见,这里有一个完整的脚本,可以重现问题,并且只有在问题存在时才会成功。

import numpy as np
from scipy.weave import blitz


def test_blitz_bug(N=4):
    ReI = 1.2
    ux_blitz_buggy, ux_blitz_not_buggy, ux_np = np.zeros((N, N)), np.zeros((N, N)), np.zeros((N, N))
    uxold = np.random.randn(N, N)
    ux_np[0:,1:-1] = uxold[0:,1:-1] + ReI* (uxold[0:,2:] - 2*uxold[0:,1:-1] + uxold[0:,0:-2])
    expr_buggy = 'ux_blitz_buggy[0:,1:-1] = uxold[0:,1:-1] + ReI* (uxold[0:,2:] - 2*uxold[0:,1:-1] + uxold[0:,0:-2])'
    expr_not_buggy = 'ux_blitz_not_buggy[:,1:-1] = uxold[:,1:-1] + ReI* (uxold[:,2:] - 2*uxold[:,1:-1] + uxold[:,0:-2])'
    blitz(expr_buggy)
    blitz(expr_not_buggy)
    assert not np.allclose(ux_blitz_buggy, ux_np)
    assert np.allclose(ux_blitz_not_buggy, ux_np)

if __name__ == '__main__':
    test_blitz_bug()
于 2013-04-26T12:17:36.413 回答