算术numpy表达式的内存消耗是多少?
vec ** 3 + vec ** 2 + vec
(vec 是一个 numpy.ndarray)。是否为每个中间操作存储了一个数组?这种复合表达式的内存是否是底层 ndarray 的数倍?
算术numpy表达式的内存消耗是多少?
vec ** 3 + vec ** 2 + vec
(vec 是一个 numpy.ndarray)。是否为每个中间操作存储了一个数组?这种复合表达式的内存是否是底层 ndarray 的数倍?
您是对的,将为每个中间结果分配一个新数组。幸运的是,该软件包numexpr
旨在处理此问题。从描述:
NumExpr 获得比 NumPy 更好的性能的主要原因是它避免为中间结果分配内存。这会导致更好的高速缓存利用率并总体上减少内存访问。因此,NumExpr 最适用于大型数组。
例子:
In [97]: xs = np.random.rand(1_000_000)
In [98]: %timeit xs ** 3 + xs ** 2 + xs
26.8 ms ± 371 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
In [99]: %timeit numexpr.evaluate('xs ** 3 + xs ** 2 + xs')
1.43 ms ± 20.5 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
感谢@max9111 指出 numexpr 简化了乘法运算。似乎基准测试中的大部分差异都可以通过优化来解释xs ** 3
。
In [421]: %timeit xs * xs
1.62 ms ± 12 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
In [422]: %timeit xs ** 2
1.63 ms ± 10.5 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
In [423]: %timeit xs ** 3
22.8 ms ± 283 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
In [424]: %timeit xs * xs * xs
2.52 ms ± 58.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)