7

我对 numpy 的广播规则有点困惑。假设您想要执行更高维度数组的轴方向标量积以将数组维度减少一(基本上是沿一个轴执行加权求和):

from numpy import *

A = ones((3,3,2))
v = array([1,2])

B = zeros((3,3))

# V01: this works
B[0,0] = v.dot(A[0,0])

# V02: this works
B[:,:] = v[0]*A[:,:,0] + v[1]*A[:,:,1] 

# V03: this doesn't
B[:,:] = v.dot(A[:,:]) 

为什么 V03 不起作用?

干杯

4

3 回答 3

4

np.dot(a, b)a 的最后一个轴和 b 的倒数第二个轴上运行。因此,对于您问题中的特定情况,您可以随时选择:

>>> a.dot(v)
array([[ 3.,  3.,  3.],
       [ 3.,  3.,  3.],
       [ 3.,  3.,  3.]])

如果要保持v.dot(a)顺序,则需要使轴就位,这可以通过以下方式轻松实现np.rollaxis

>>> v.dot(np.rollaxis(a, 2, 1))
array([[ 3.,  3.,  3.],
       [ 3.,  3.,  3.],
       [ 3.,  3.,  3.]])

我不太喜欢,除非是为了明显的矩阵或向量乘法,因为在使用可选参数np.dot时对输出dtype非常严格。outJoe Kington 已经提到过它,但如果你要做这类事情,请习惯np.einsum:一旦你掌握了语法,它就会将你花在重塑事物上的时间减少到最低限度:

>>> a = np.ones((3, 3, 2))
>>> np.einsum('i, jki', v, a)
array([[ 3.,  3.,  3.],
       [ 3.,  3.,  3.],
       [ 3.,  3.,  3.]])

并不是说它在这种情况下太相关,但它也快得离谱:

In [4]: %timeit a.dot(v)
100000 loops, best of 3: 2.43 us per loop

In [5]: %timeit v.dot(np.rollaxis(a, 2, 1))
100000 loops, best of 3: 4.49 us per loop

In [7]: %timeit np.tensordot(v, a, axes=(0, 2))
100000 loops, best of 3: 14.9 us per loop

In [8]: %timeit np.einsum('i, jki', v, a)
100000 loops, best of 3: 2.91 us per loop
于 2013-03-08T20:40:43.980 回答
3

tensordot在这种特殊情况下,您也可以使用,。

import numpy as np

A = np.ones((3,3,2))
v = np.array([1,2])

print np.tensordot(v, A, axes=(0, 2))

这产生:

array([[ 3.,  3.,  3.],
       [ 3.,  3.,  3.],
       [ 3.,  3.,  3.]])

axes=(0,2)表示tensordot应该对 中的第一个轴和v中的第三个轴求和A。(还可以查看einsum,它更灵活,但如果您不习惯这种符号,则更难理解。)

如果tensordot考虑速度,则比apply_along_axes用于小型阵列要快得多。

In [14]: A = np.ones((3,3,2))

In [15]: v = np.array([1,2])

In [16]: %timeit np.tensordot(v, A, axes=(0, 2))
10000 loops, best of 3: 21.6 us per loop

In [17]: %timeit np.apply_along_axis(v.dot, 2, A)
1000 loops, best of 3: 258 us per loop

(由于持续的开销,对于大型阵列,差异不太明显,但tensordot始终更快。)

于 2013-03-08T17:59:04.900 回答
2

你可以用numpy.apply_along_axis()这个:

In [35]: np.apply_along_axis(v.dot, 2, A)
Out[35]: 
array([[ 3.,  3.,  3.],
       [ 3.,  3.,  3.],
       [ 3.,  3.,  3.]])

我认为V03不起作用的原因是它没有什么不同:

B[:,:] = v.dot(A) 

即它试图计算沿最外轴的点积A

于 2013-03-08T15:07:53.617 回答