你没有说太多,SG_Z
但我怀疑它是 2d (或更高)。 numba
具有有限的多维索引能力(相比numpy
)
In [133]: arr = np.random.rand(3,4)
In [134]: arr
Out[134]:
array([[0.8466427 , 0.37340328, 0.07712635, 0.34466743],
[0.86591184, 0.32048868, 0.1260246 , 0.9811717 ],
[0.28948191, 0.32099879, 0.54819722, 0.78863841]])
In [135]: arr<.5
Out[135]:
array([[False, True, True, True],
[False, True, True, False],
[ True, True, False, False]])
In [136]: arr[arr<.5]
Out[136]:
array([0.37340328, 0.07712635, 0.34466743, 0.32048868, 0.1260246 ,
0.28948191, 0.32099879])
numba
:_
In [137]: import numba
In [138]: @numba.njit
...: def foo(arr, thresh):
...: arr[arr<.5]=0
...: return arr
...:
In [139]: foo(arr,.5)
Traceback (most recent call last):
File "<ipython-input-139-33ea2fda1ea2>", line 1, in <module>
foo(arr,.5)
File "/usr/local/lib/python3.8/dist-packages/numba/core/dispatcher.py", line 420, in _compile_for_args
error_rewrite(e, 'typing')
File "/usr/local/lib/python3.8/dist-packages/numba/core/dispatcher.py", line 361, in error_rewrite
raise e.with_traceback(None)
TypingError: No implementation of function Function(<built-in function setitem>) found for signature:
>>> setitem(array(float64, 2d, C), array(bool, 2d, C), Literal[int](0))
There are 16 candidate implementations:
- Of which 14 did not match due to:
Overload of function 'setitem': File: <numerous>: Line N/A.
With argument(s): '(array(float64, 2d, C), array(bool, 2d, C), int64)':
No match.
- Of which 2 did not match due to:
Overload in function 'SetItemBuffer.generic': File: numba/core/typing/arraydecl.py: Line 171.
With argument(s): '(array(float64, 2d, C), array(bool, 2d, C), int64)':
Rejected as the implementation raised a specific error:
TypeError: unsupported array index type array(bool, 2d, C) in [array(bool, 2d, C)]
raised from /usr/local/lib/python3.8/dist-packages/numba/core/typing/arraydecl.py:68
During: typing of setitem at <ipython-input-138-6861f217f595> (3)
一般来说,它不会setitem
丢失。numba
一直这样做。这是setitem
针对这种特殊的论点组合。
如果我首先解开阵列,它确实有效。
In [140]: foo(arr.ravel(),.5)
Out[140]:
array([0.8466427 , 0. , 0. , 0. , 0.86591184,
0. , 0. , 0.9811717 , 0. , 0. ,
0.54819722, 0.78863841])
但是numba
我们不需要害怕迭代,所以对于 2d 输入,我们可以对行进行迭代:
In [148]: @numba.njit
...: def foo(arr, thresh):
...: for i in arr:
...: i[i<thresh] = 0
...: return arr
...:
In [149]: foo(arr,.5)
Out[149]:
array([[0.8466427 , 0. , 0. , 0. ],
[0.86591184, 0. , 0. , 0.9811717 ],
[0. , 0. , 0.54819722, 0.78863841]])
可能有更通用的方式来写这个,并提供签名,但这应该给出一些关于如何解决这个问题的想法。