1

在 Numba 中进行类型推断时,我肯定遗漏了一些东西。我写了这个小样本,似乎无法找出编译时无法推断类型的原因。您可以通过尝试运行它来重现:

import numpy as np
from numba import njit, prange


@njit(["void(uint8[::1], int16)"])
def run_a(arr, numeric):
    final_result = run_b(arr, numeric)   # <<== FAILS HERE
    print(final_result)


@njit(["float64(uint8[::1], int16)"], fastmath=True, nogil=True)
def run_b(arr, numeric):
    value1 = 1.0
    value_array = np.zeros(numeric, np.float64)

    return value1 + value_array.sum()


array = np.arange(10, dtype=np.uint8)
run_a(array, 2)

它失败了:

numba.core.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Untyped global name 'run_b': cannot determine Numba type of <class 'numba.core.ir.UndefinedType'>

File "tests.py", line 7:
def run_a(arr, numeric):
    final_result = run_b(arr, numeric)
    ^

为什么它无法推断该调用中的类型?我什至对函数签名进行了注释,所以据我所知,它们是什么类型是毫无疑问的。

顺便说一句,我知道 Numba 可能不会获得巨大的收益。这只是一个示例代码,我试图了解问题所在。

我错过了什么?如何让它编译?

4

1 回答 1

0

我发现了问题......必须提升函数,否则它们将无法编译。因此,颠倒顺序以便 run_b 出现在 run_a 之前解决了问题。

于 2020-07-11T23:58:12.497 回答