6

我有一个形状为 (M,N) 的数组 A 现在我想做这个操作

R = (A[:,newaxis,:] * A[newaxis,:,:]).sum(2)

这应该产生一个(MxM)数组。现在的问题是数组非常大,我得到一个内存错误,因为 MxMxN 数组不适合内存。

完成这项工作的最佳策略是什么?C?地图()?还是有一个特殊的功能呢?

谢谢你,大卫

4

1 回答 1

7

我不确定你的数组有多大,但以下是等价的:

R = np.einsum('ij,kj',A,A)

并且可以更快,并且内存占用更少:

In [7]: A = np.random.random(size=(500,400))

In [8]: %timeit R = (A[:,np.newaxis,:] * A[np.newaxis,:,:]).sum(2)
1 loops, best of 3: 1.21 s per loop

In [9]: %timeit R = np.einsum('ij,kj',A,A)
10 loops, best of 3: 54 ms per loop

如果我增加Ato的大小(500,4000)np.einsum则在大约 2 秒内完成计算,而原始公式由于它必须创建的临时数组的大小而使我的机器停止运行。

更新

正如@Jaime 在评论中指出的那样,这np.dot(A,A.T)也是问题的等效表述,甚至可以比np.einsum解决方案更快。完全归功于他指出这一点,但如果他没有将其作为正式解决方案发布,我想将其拉到主要答案中。

于 2013-06-07T14:39:21.243 回答