我试图理解为什么 numpy 的dot
函数会这样:
M = np.ones((9, 9))
V1 = np.ones((9,))
V2 = np.ones((9, 5))
V3 = np.ones((2, 9, 5))
V4 = np.ones((3, 2, 9, 5))
现在np.dot(M, V1)
并按np.dot(M, V2)
预期行事。但V3
结果V4
令我惊讶:
>>> np.dot(M, V3).shape
(9, 2, 5)
>>> np.dot(M, V4).shape
(9, 3, 2, 5)
我预计(2, 9, 5)
和(3, 2, 9, 5)
分别。另一方面,np.matmul
符合我的预期:矩阵乘法在第二个参数的前N - 2 个维度上广播,结果具有相同的形状:
>>> np.matmul(M, V3).shape
(2, 9, 5)
>>> np.matmul(M, V4).shape
(3, 2, 9, 5)
所以我的问题是:这样做的理由是
np.dot
什么?它是服务于某个特定目的,还是应用某些一般规则的结果?