2

我试图理解一个 python 代码,它用于numpy.einsum()将 4 维 numpy 数组 , 转换为 2 维A或 3 维数组。传递给的下标numpy.einsum()如下:

Mat1 = np.einsum('aabb->ab', A) 

Mat2 = np.einsum('abab->ab', A)

Mat3 = np.einsum('abba->ab', A) 

T1 = np.einsum('abcb->abc' A)

T2 = np.einsum('abbc->abc', A)

等等。按照(Understanding NumPy's einsum)和(Python - Sum 4D Array)的答案,例如,我试图用它numpy.sum()来理解上述下标的含义,Mat1 = np.sum(A, axis=(0,3))但我无法重现我得到的结果numpy.einsum()。有人可以解释一下这些下标是如何解释的numpy.einsum()吗?

4

1 回答 1

2

我建议您阅读Wikipedia 上的 Einstein notation

以下是对您问题的简短回答:

np.einsum('aabb->ab', A)

方法:

res = np.empty((max_a, max_b), dtype=A.dtype)
for a in range(max_a):
  for b in range(max_b):
    res[a, b] = A[a, a, b, b]
return res

简短解释:
aabb表示索引及其相等性(参见A[a, a, b, b]);
->ab表示形状是(max_a, max_b)并且您不需要两个对这两个索引进行总和。(如果他们c也是,那么你应该总结一切,c因为它没有在之后出现->


其他你的例子:

np.einsum('abab->ab', A)

# Same as (by logic, not by actual code)

res = np.empty((max_a, max_b), dtype=A.dtype)
for a in range(max_a):
  for b in range(max_b):
    res[a, b] = A[a, b, a, b]
return res
np.einsum('abba->ab', A) 

# Same as (by logic, not by actual code)

res = np.empty((max_a, max_b), dtype=A.dtype)
for a in range(max_a):
  for b in range(max_b):
    res[a, b] = A[a, b, b, a]
return res
np.einsum('abcb->abc', A)

# Same as (by logic, not by actual code)

res = np.empty((max_a, max_b, max_c), dtype=A.dtype)
for a in range(max_a):
  for b in range(max_b):
    for c in range(max_c):
      res[a, b, c] = A[a, b, c, b]
return res
np.einsum('abbc->abc', A)

# Same as (by logic, not by actual code)

res = np.empty((max_a, max_b, max_c), dtype=A.dtype)
for a in range(max_a):
  for b in range(max_b):
    for c in range(max_c):
      res[a, b, c] = A[a, b, b, c]
return res

一些代码来检查它是否真的是真的:

import numpy as np


max_a = 2
max_b = 3
max_c = 5

shape_1 = (max_a, max_b, max_c, max_b)
A = np.arange(1, np.prod(shape_1) + 1).reshape(shape_1)

print(A)
print()
print(np.einsum('abcb->abc', A))
print()

res = np.empty((max_a, max_b, max_c), dtype=A.dtype)
for a in range(max_a):
  for b in range(max_b):
    for c in range(max_c):
      res[a, b, c] = A[a, b, c, b]

print(res)
print()

于 2019-05-17T23:16:13.897 回答