3

我一直在尝试学习 Cython 来加快我的一些计算。这是我正在尝试做的一个子集:这只是在使用 NumPy 数组的同时使用递归公式对微分方程进行积分。与纯 python 版本相比,我已经实现了约 100 倍的速度提升。-a但是,通过查看 cython 命令为我的代码生成的 HTML 文件,我似乎可以提高速度。我的代码如下(在 HTML 文件中我想变成白色的行被标记为黄色):

%%cython
import numpy as np
cimport numpy as np
cimport cython
from libc.math cimport exp,sqrt

@cython.boundscheck(False)
cdef double riccati_int(double j, double w, double h, double an, double d):
    cdef:
        double W
        double an1
    W = sqrt(w**2 + d**2)
    #dark_yellow
    an1 = ((d - (W + w) * an) * exp(-2 * W * h / j ) - d - (W - w) * an) / 
          ((d * an - W + w) * exp(-2 * W * h / j) - d * an - W - w) 
    return an1


def acalc(double j, double w):
    cdef:
        int xpos, i, n
        np.ndarray[np.int_t, ndim=1] xvals
        np.ndarray[np.double_t, ndim=1] h, a
    xpos = 74
    xvals = np.array([0, 8, 23, 123, 218], dtype=np.int)     #dark_yellow
    h = np.array([1, .1, .01, .1], dtype=np.double)          #dark_yellow
    a = np.empty(219, dtype=np.double)                       #dark_yellow
    a[0] = 1 / (w + sqrt(w**2 + 1))                          #light_yellow

    for i in range(h.size):                                  #dark_yellow
        for n in range(xvals[i], xvals[i + 1]):              #light_yellow
            if n < xpos:
                a[n+1] = riccati_int(j, w, h[i], a[n], 1.)   #light_yellow
            else:
                a[n+1] = riccati_int(j, w, h[i], a[n], 0.)   #light_yellow
    return a  

在我看来,我上面标记的所有 9 行都应该能够通过适当的调整变成白色。一个问题是能够以正确的方式定义 NumPy 数组。但可能更重要的是让第一条标记线有效工作的能力,因为这是完成大部分计算的地方。我尝试阅读 HTML 文件在单击黄线后显示的生成的 C 代码,但老实说,我不知道如何阅读该代码。如果有人可以帮助我,将不胜感激。

4

2 回答 2

1

我认为您不需要关心不在循环中的黄线。添加以下编译器指令将使循环中的三行更快:

@cython.cdivision(True)
cdef double riccati_int(double j, double w, double h, double an, double d):
    pass

@cython.boundscheck(False)
@cython.wraparound(False)
def acalc(double j, double w):
    pass
于 2013-08-05T04:02:46.480 回答
0

我不确定它是否有所作为,但你可以对数组使用内存视图,例如

cdef double [:] h = np.array([1, .1, .01, .1], dtype=np.double) #dark_yellow
cdef double [:] a = np.empty(219, dtype=np.double)              #dark_yellow

同样为四个静态值创建一个 numpy 数组有点过头了。这可以用静态 C 数组替换

cdef double *h = [1, .1, .01, .1]

但是,如前所述,循环中的内容是最重要的。由于 line profiler 不适用于 cython (afaik) 使用time模块在函数内进行基准测试,除了使用cProfile. 它可能会给您一个想法,必须在上下文中评估 cython 日志中线条颜色的强度。

我所知,建议使用 python 类型进行索引

size_t i, n
Py_ssize_t i, n

第二个是签名版

于 2014-01-08T20:37:56.943 回答