In [70]: a = np.arange(24).reshape(2,3,4)
In [71]: np.tensordot(a,a,axes=2)
Traceback (most recent call last):
File "<ipython-input-71-dbe04e46db70>", line 1, in <module>
np.tensordot(a,a,axes=2)
File "<__array_function__ internals>", line 5, in tensordot
File "/usr/local/lib/python3.8/dist-packages/numpy/core/numeric.py", line 1116, in tensordot
raise ValueError("shape-mismatch for sum")
ValueError: shape-mismatch for sum
在我之前的帖子中,我推断出axis=2
转化为axes=([-2,-1],[0,1])
numpy.tensordot 函数如何逐步工作?
In [72]: np.tensordot(a,a,axes=([-2,-1],[0,1]))
Traceback (most recent call last):
File "<ipython-input-72-efdbfe6ff0d3>", line 1, in <module>
np.tensordot(a,a,axes=([-2,-1],[0,1]))
File "<__array_function__ internals>", line 5, in tensordot
File "/usr/local/lib/python3.8/dist-packages/numpy/core/numeric.py", line 1116, in tensordot
raise ValueError("shape-mismatch for sum")
ValueError: shape-mismatch for sum
a
所以这是试图对第一个的最后两个维度和第二个的前两个维度进行双轴缩减a
。这a
就是尺寸不匹配。显然,这axes
是为 2d 数组设计的,没有过多考虑 3d 数组。这不是 3 轴缩减。
一些开发人员认为这些单个数字轴值会很方便,但这并不意味着它们经过了严格的考虑或测试。
元组轴为您提供更多控制:
In [74]: np.tensordot(a,a,axes=[(0,1,2),(0,1,2)])
Out[74]: array(4324)
In [75]: np.tensordot(a,a,axes=[(0,1),(0,1)])
Out[75]:
array([[ 880, 940, 1000, 1060],
[ 940, 1006, 1072, 1138],
[1000, 1072, 1144, 1216],
[1060, 1138, 1216, 1294]])