随着最近对 Numpy (1.14) 的更新,我发现它破坏了我的整个代码库。这是基于将默认的 numpy einsum 优化参数从 False 更改为 True。
结果,以下基本操作现在失败:
a = np.random.random([50, 2, 2])
b = np.random.random([50, 2])
np.einsum('bdc, ac -> ab', a, b, optimize=True)
带有以下错误跟踪:
ValueError Traceback (most recent call last)
<ipython-input-71-b0f9ce3c71a3> in <module>()
----> 1 np.einsum('bdc, ac -> ab', a, b, optimize=True)
C:\ProgramData\Anaconda3\lib\site-packages\numpy\core\einsumfunc.py in
einsum(*operands, **kwargs)
1118
1119 # Contract!
-> 1120 new_view = tensordot(*tmp_operands, axes=
(tuple(left_pos), tuple(right_pos)))
1121
1122 # Build a new view if needed
C:\ProgramData\Anaconda3\lib\site-packages\numpy\core\numeric.py in
tensordot(a, b, axes)
1301 oldb = [bs[axis] for axis in notin]
1302
-> 1303 at = a.transpose(newaxes_a).reshape(newshape_a)
1304 bt = b.transpose(newaxes_b).reshape(newshape_b)
1305 res = dot(at, bt)
ValueError: axes don't match array
我向 einsum 请求的操作看起来很简单……那为什么会失败呢?如果我设置“optimize=False”,它工作得很好。
我尝试使用 einsum_path 进行探索,但生成的路径信息在优化和不优化的情况下是相同的。