-1

我知道有很多关于 tensordot 的问题,我已经浏览了一些 15 页的迷你书答案,我确信这些答案是人们花费数小时制作的,但我还没有找到关于什么的解释axes=2

这让我觉得np.tensordot(b,c,axes=2) == np.sum(b * c),但作为一个数组:

b = np.array([[1,10],[100,1000]])
c = np.array([[2,3],[5,7]])
np.tensordot(b,c,axes=2)
Out: array(7532)

但后来这失败了:

a = np.arange(30).reshape((2,3,5))
np.tensordot(a,a,axes=2)

如果有人可以提供一个简短、简洁的解释np.tensordot(x,y,axes=2),并且只有axes=2,那么我很乐意接受。

4

1 回答 1

0
In [70]: a = np.arange(24).reshape(2,3,4)
In [71]: np.tensordot(a,a,axes=2)
Traceback (most recent call last):
  File "<ipython-input-71-dbe04e46db70>", line 1, in <module>
    np.tensordot(a,a,axes=2)
  File "<__array_function__ internals>", line 5, in tensordot
  File "/usr/local/lib/python3.8/dist-packages/numpy/core/numeric.py", line 1116, in tensordot
    raise ValueError("shape-mismatch for sum")
ValueError: shape-mismatch for sum

在我之前的帖子中,我推断出axis=2转化为axes=([-2,-1],[0,1])

numpy.tensordot 函数如何逐步工作?

In [72]: np.tensordot(a,a,axes=([-2,-1],[0,1]))
Traceback (most recent call last):
  File "<ipython-input-72-efdbfe6ff0d3>", line 1, in <module>
    np.tensordot(a,a,axes=([-2,-1],[0,1]))
  File "<__array_function__ internals>", line 5, in tensordot
  File "/usr/local/lib/python3.8/dist-packages/numpy/core/numeric.py", line 1116, in tensordot
    raise ValueError("shape-mismatch for sum")
ValueError: shape-mismatch for sum

a所以这是试图对第一个的最后两个维度和第二个的前两个维度进行双轴缩减a。这a就是尺寸不匹配。显然,这axes是为 2d 数组设计的,没有过多考虑 3d 数组。这不是 3 轴缩减。

一些开发人员认为这些单个数字轴值会很方便,但这并不意味着它们经过了严格的考虑或测试。

元组轴为您提供更多控制:

In [74]: np.tensordot(a,a,axes=[(0,1,2),(0,1,2)])
Out[74]: array(4324)
In [75]: np.tensordot(a,a,axes=[(0,1),(0,1)])
Out[75]: 
array([[ 880,  940, 1000, 1060],
       [ 940, 1006, 1072, 1138],
       [1000, 1072, 1144, 1216],
       [1060, 1138, 1216, 1294]])
于 2021-02-15T21:20:39.810 回答