
Based on the explanation given here [1], I am trying to use the same idea to speed up the following integral:

import numpy as np
import scipy.integrate as si
from scipy.optimize import root, fsolve
import numba
from numba import cfunc
from numba.types import intc, CPointer, float64
from scipy import LowLevelCallable

def integrand(t, *args):
    a = args[0]
    c = fsolve(lambda x: a * x**2 - np.exp(-x**2 / a), 1)[0]
    return c * np.exp(- (t / (a * c))**2) 

def do_integrate(func, a):
    return si.quad(func, 0, 1, args=(a,))

print(do_integrate(integrand, 2.)[0]) 

Following that reference, I tried using numba's jit and modified the previous block as follows:

import numpy as np
import scipy.integrate as si
from scipy.optimize import root, fsolve
import numba
from numba import cfunc
from numba.types import intc, CPointer, float64
from scipy import LowLevelCallable

def jit_integrand_function(integrand_function):
    jitted_function = numba.jit(integrand_function, nopython=True)  
    @cfunc(float64(intc, CPointer(float64)))
    def wrapped(n, xx):
        return jitted_function(xx[0], xx[1])
    return LowLevelCallable(wrapped.ctypes)

@jit_integrand_function
def integrand(t, *args):
    a = args[0]
    c = fsolve(lambda x: a * x**2 - np.exp(-x**2 / a), 1)[0]
    return c * np.exp(- (t / (a * c))**2)

def do_integrate(func, a):
    return si.quad(func, 0, 1, args=(a,))

do_integrate(integrand, 2.)

However, this implementation gives me the following error:


TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Failed in nopython mode pipeline (step: convert make_function into JIT functions)
Cannot capture the non-constant value associated with variable 'a' in a function that will escape.

File "<ipython-input-16-3d98286a4be7>", line 20:
def integrand(t, *args):
    <source elided>
    a = args[0]
    c = fsolve(lambda x: a * x**2 - np.exp(-x**2 / a), 1)[0]
    ^

During: resolving callee type: type(CPUDispatcher(<function integrand at 0x11a949d08>))
During: typing of call at <ipython-input-16-3d98286a4be7> (14)

During: resolving callee type: type(CPUDispatcher(<function integrand at 0x11a949d08>))
During: typing of call at <ipython-input-16-3d98286a4be7> (14)

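The error comes from calling fsolve from scipy.optimize (with a lambda that captures `a`) inside the integrand.

I would like to know whether there is a way around this error, and whether scipy.optimize.fsolve can be used with numba in this situation at all.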
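One workaround that avoids fsolve entirely (a swapped-in technique, not the answer's approach) is to solve the scalar equation a*x**2 - exp(-x**2/a) = 0 with a hand-written Newton iteration that numba can compile, and to pass `a` as an ordinary argument instead of capturing it in a lambda. The sketch below assumes a starting guess of 1.0 (as in the question) and that Newton converges from there; `solve_c` is a hypothetical helper name. The `try/except` fallback to an identity decorator exists only so the sketch still runs where numba is not installed.

```python
import math

import numpy as np
import scipy.integrate as si

try:
    from numba import njit
except ImportError:  # fallback: run as plain Python if numba is unavailable
    def njit(func):
        return func


@njit
def solve_c(a, x0=1.0, tol=1e-12, maxiter=50):
    """Hypothetical helper: Newton iteration for a*x**2 - exp(-x**2/a) = 0."""
    x = x0
    for _ in range(maxiter):
        e = math.exp(-x * x / a)
        f = a * x * x - e
        # f'(x) = 2*a*x + (2*x/a) * exp(-x**2/a)
        df = 2.0 * a * x + (2.0 * x / a) * e
        step = f / df
        x -= step
        if abs(step) < tol:
            break
    return x


@njit
def integrand(t, a):
    # No closure over `a`: it is passed explicitly, so numba can type everything.
    c = solve_c(a)
    return c * math.exp(-(t / (a * c)) ** 2)


def do_integrate(func, a):
    return si.quad(func, 0, 1, args=(a,))


print(do_integrate(integrand, 2.0)[0])
```

Because `integrand` is now an ordinary nopython function of `(t, a)`, it should in principle also be usable with the question's `jit_integrand_function` decorator (applied instead of `@njit`) to obtain a LowLevelCallable.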


1 Answer


I wrote a small Python wrapper for Minpack, called NumbaMinpack, that can be called from within numba-compiled functions: https://github.com/Nicholaswogan/NumbaMinpack. You can use it in an @njit integrand:

import scipy.integrate as si
from NumbaMinpack import hybrd, minpack_sig
from numba import njit, cfunc
import numpy as np

@cfunc(minpack_sig)
def f(x, fvec, args):
    a = args[0]
    fvec[0] = a * x[0]**2.0 - np.exp(-x[0]**2.0 / a)

funcptr = f.address # pointer to function  

@njit
def integrand(t, *args):
    a = args[0]
    args_ = np.array(args)
    x_init = np.array([1.0])
    sol = hybrd(funcptr,x_init,args_)
    c = sol[0][0]
    return c * np.exp(- (t / (a * c))**2) 

def do_integrate(func, a):
    return si.quad(func, 0, 1, args=(a,))

print(do_integrate(integrand, 2.)[0]) 

On my computer, the code above takes 87 µs, while the pure Python version takes 2920 µs.
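The answer does not say how these timings were taken. One way to measure a comparable per-call figure is a small `timeit` harness like the sketch below, shown here on the pure Python (fsolve-based) baseline; `time_per_call_us` is a hypothetical helper, and the integrand takes `(t, a)` instead of `(t, *args)`, which is equivalent for a single parameter. The warm-up call matters when timing a numba function, so that JIT compilation is not counted.

```python
import timeit

import numpy as np
import scipy.integrate as si
from scipy.optimize import fsolve


def integrand(t, a):
    # Pure Python baseline from the question: fsolve runs on every evaluation.
    c = fsolve(lambda x: a * x**2 - np.exp(-x**2 / a), 1)[0]
    return c * np.exp(-(t / (a * c)) ** 2)


def do_integrate(func, a):
    return si.quad(func, 0, 1, args=(a,))


def time_per_call_us(func, a, number=20):
    """Average wall time of one do_integrate(func, a) call, in microseconds."""
    do_integrate(func, a)  # warm-up (would trigger JIT compilation for numba functions)
    total = timeit.timeit(lambda: do_integrate(func, a), number=number)
    return total / number * 1e6


print(f"pure Python: {time_per_call_us(integrand, 2.0):.0f} us per integral")
```

Passing the NumbaMinpack-based `integrand` from the answer to `time_per_call_us` instead gives the other side of the comparison; absolute numbers will depend on the machine.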

answered 2021-05-09T07:45:03.690