我对 numba 了解不多,但如果我们对它在幕后所做的事情做出一些基本假设,我们可以推断为什么 autojit 版本较慢以及如何通过微小的更改来加快速度......
让我们从 sum_arr 开始,
1 def sum_arr(arr):
2 z = arr.copy()
3 M = len(arr)
4 for i in range(M):
5 z[i] += arr[i]
6
7 return z
很清楚这里发生了什么,但让我们选择第 5 行,它可以重写为
1 a = arr[i]
2 b = z[i]
3 c = a + b
4 z[i] = c
Python 将进一步改变这一点
1 a = arr.__getitem__(i)
2 b = arr.__getitem__(i)
3 c = a.__add__(b)
4 z.__setitem__(i, c)
a,b 和 c 都是 numpy.int64 (或类似)的实例
我怀疑 numba 正在尝试检查这些项目的日期类型并将它们转换为一些 numba 本机数据类型(我看到的 numpy 代码最大的减速之一是无意中从 python 数据类型切换到 numpy 数据类型)。如果这确实发生了,numba 至少进行 3 次转换,2 次 numpy.int64 -> 本机,1 次本机 -> numpy.int64,或者中间体可能更糟(numpy.int64 -> python int -> 本机(c诠释))。我怀疑 numba 会在检查数据类型时增加额外的开销,可能根本不会优化循环。让我们看看如果我们从循环中删除类型更改会发生什么......
1 @autojit
2 def fast_sum_arr2(arr):
3 z = arr.tolist()
4 M = len(arr)
5 for i in range(M):
6 z[i] += arr[i]
7
8 return numpy.array(z)
第 3 行的细微变化,tolist 而不是 copy,将数据类型更改为 Python ints,但我们仍然在第 6 行有一个 numpy.int64 -> native。让我们将其重写为 z[i] += z[i]
1 @autojit
2 def fast_sum_arr3(arr):
3 z = arr.tolist()
4 M = len(arr)
5 for i in range(M):
6 z[i] += z[i]
7
8 return numpy.array(z)
随着所有的变化,我们看到了相当大的加速(尽管它不一定能击败纯 python)。当然,arr+arr 只是愚蠢的快。
1 import numpy
2 from numba import autojit
3
4 def sum_arr(arr):
5 z = arr.copy()
6 M = len(arr)
7 for i in range(M):
8 z[i] += arr[i]
9
10 return z
11
12 @autojit
13 def fast_sum_arr(arr):
14 z = arr.copy()
15 M = len(arr)
16 for i in range(M):
17 z[i] += arr[i]
18
19 return z
20
21 def sum_arr2(arr):
22 z = arr.tolist()
23 M = len(arr)
24 for i in range(M):
25 z[i] += arr[i]
26
27 return numpy.array(z)
28
29 @autojit
30 def fast_sum_arr2(arr):
31 z = arr.tolist()
32 M = len(arr)
33 for i in range(M):
34 z[i] += arr[i]
35
36 return numpy.array(z)
37
38 def sum_arr3(arr):
39 z = arr.tolist()
40 M = len(arr)
41 for i in range(M):
42 z[i] += z[i]
43
44 return numpy.array(z)
45
46 @autojit
47 def fast_sum_arr3(arr):
48 z = arr.tolist()
49 M = len(arr)
50 for i in range(M):
51 z[i] += z[i]
52
53 return numpy.array(z)
54
55 def sum_arr4(arr):
56 return arr+arr
57
58 @autojit
59 def fast_sum_arr4(arr):
60 return arr+arr
61
62 arr = numpy.arange(1000)
还有时间,
In [1]: %timeit sum_arr(arr)
10000 loops, best of 3: 129 us per loop
In [2]: %timeit sum_arr2(arr)
1000 loops, best of 3: 232 us per loop
In [3]: %timeit sum_arr3(arr)
10000 loops, best of 3: 51.8 us per loop
In [4]: %timeit sum_arr4(arr)
100000 loops, best of 3: 3.68 us per loop
In [5]: %timeit fast_sum_arr(arr)
1000 loops, best of 3: 216 us per loop
In [6]: %timeit fast_sum_arr2(arr)
10000 loops, best of 3: 65.6 us per loop
In [7]: %timeit fast_sum_arr3(arr)
10000 loops, best of 3: 56.5 us per loop
In [8]: %timeit fast_sum_arr4(arr)
100000 loops, best of 3: 2.03 us per loop