在 Numba 中进行类型推断时,我肯定遗漏了一些东西。我写了这个小样本,似乎无法找出编译时无法推断类型的原因。您可以通过尝试运行它来重现:
import numpy as np
from numba import njit, prange
@njit(["void(uint8[::1], int16)"])
def run_a(arr, numeric):
final_result = run_b(arr, numeric) # <<== FAILS HERE
print(final_result)
@njit(["float64(uint8[::1], int16)"], fastmath=True, nogil=True)
def run_b(arr, numeric):
value1 = 1.0
value_array = np.zeros(numeric, np.float64)
return value1 + value_array.sum()
array = np.arange(10, dtype=np.uint8)
run_a(array, 2)
它失败了:
numba.core.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Untyped global name 'run_b': cannot determine Numba type of <class 'numba.core.ir.UndefinedType'>
File "tests.py", line 7:
def run_a(arr, numeric):
final_result = run_b(arr, numeric)
^
为什么它无法推断该调用中的类型?我什至对函数签名进行了注释,所以据我所知,它们是什么类型是毫无疑问的。
顺便说一句,我知道 Numba 可能不会获得巨大的收益。这只是一个示例代码,我试图了解问题所在。
我错过了什么?如何让它编译?