12

好的,所以我有几个 sympy 对象(表达式)的多维 numpy 数组。例如:

A = array([[1.0*cos(z0)**2 + 1.0, 1.0*cos(z0)],
          [1.0*cos(z0), 1.00000000000000]], dtype=object)

等等。

我想做的是使用 einsum 将这些数组中的几个相乘,因为我已经从我之前做的数值计算中获得了它的语法。问题是,当我尝试做类似的事情时

einsum('ik,jkim,j', A, B, C)

我收到一个类型错误:

TypeError: invalid data type for einsum

当然,所以在谷歌上的快速搜索显示我 einsum 可能无法做到这一点,但没有理由说明原因。特别是,在这些数组上调用 numpy.dot() 和 numpy.tensordot() 函数就像一个魅力。我可以使用 tensordot 来做我需要的事情,但是当我想到必须用嵌套的 tensordot 调用替换 50 个左右的 Einsten 求和(如上面的那个(其中 indeces 的顺序非常重要))时,我的大脑很痛。更可怕的是不得不调试该代码并寻找那个放错位置的索引交换。

长话短说,有谁知道为什么 tensordot 可以处理对象但 einsum 不能?对解决方法有什么建议吗?如果没有,关于我将如何编写自己的包装器来嵌套 tensordot 调用的任何建议,这有点类似于 einsum 表示法(数字而不是字母很好)?

4

3 回答 3

4

Einsum 基本上取代了 tensordot(不是 dot,因为 dot 通常使用优化的线性代数包),在代码方面它完全不同。

这是一个对象 einsum,它未经测试(对于更复杂的事情),但我认为它应该可以工作......在 C 中做同样的事情可能更简单,因为你可以从真正的 einsum 函数中窃取除循环本身之外的所有内容。因此,如果您愿意,请实施它并让更多人开心...

https://gist.github.com/seberg/5236560

我不会保证任何事情,尤其是对于更奇怪的极端情况。当然,您也可以将 einsum 表示法转换为 tensordot 表示法,我敢肯定,这可能会更快一些,因为循环最终大部分都在 C 中......

于 2013-03-25T11:38:57.110 回答
4

有趣的是,添加optimize="optimal"对我有用

einsum('ik,jkim,j', A, B, C)产生错误,但是

einsum('ik,jkim,j', A, B, C, optimize="optimal")与 sympy 完美配合。

于 2021-02-28T13:55:00.233 回答
2

这是一个更简单的实现,它将einsumin 多个tensordots 分开。

def einsum(string, *args):
    index_groups = map(list, string.split(','))
    assert len(index_groups) == len(args)
    tensor_indices_tuples = zip(index_groups, args)
    return reduce(einsum_for_two, tensor_indices_tuples)[1]

def einsum_for_two(tensor_indices1, tensor_indices2):
    string1, tensor1 = tensor_indices1
    string2, tensor2 = tensor_indices2
    sum_over_indices = set(string1).intersection(set(string2))
    new_string = string1 + string2
    axes = ([], [])
    for i in sum_over_indices:
        new_string.remove(i)
        new_string.remove(i)
        axes[0].append(string1.index(i))
        axes[1].append(string2.index(i))
    return new_string, np.tensordot(tensor1, tensor2, axes)

首先,它将einsum(索引,张量)元组中的参数分开。然后它减少列表如下:

  • 获取前两个元组,并对einsum_for_two它们进行简单计算。它还打印出新的索引签名。
  • 的值einsum_for_two与列表中的下一个元组一起用作 的新参数einsum_for_two
  • 继续直到只剩下元组。索引签名被丢弃,只返回张量。

它可能很慢(但无论如何你正在使用object dtype)。它不会对输入进行很多正确性检查。

正如@seberg 所指出的,我的代码不适用于张量的痕迹。

于 2013-03-25T16:26:59.807 回答