我有一个我正在尝试优化的实时图像处理程序,这一切都归结为矩阵乘法。考虑我在初始化阶段计算的 3 个张量:
A = np.arange(35 * 51 * 59).reshape([35, 51, 59])
B = np.arange(37 * 51 * 51 * 59).reshape([37, 51, 51, 59])
C = np.arange(59 * 27).reshape([59, 27])
每一帧,我都会以第四张量的形式获得一个新数据:
M = np.arange(35 * 37 * 59).reshape([35, 37, 59])
.
目前,我正在计算D = np.einsum('xyf,xtf,ytpf,fr->tpr', M, A, B, C)
,D
我想要的结果在哪里,这是程序的主要瓶颈。为了优化它,我试图遵循两个方向。
首先我尝试提出一个张量T
,我可以预先计算一个函数A, B, C, D
,然后它会全部沸腾为D = np.tensordot(M, T, axes=..)
. 我没有成功。我花了很多时间,这甚至可能吗?
此外,程序本身是用 MATLAB 编写的。由于它没有内置的张量乘法函数(einsum
或tensordot
等效函数),我目前正在使用该tprod
工具箱,并且正在执行以下操作:
temp1 = etprod('dcb', A, 'abc', M, 'adc');
temp2 = etprod('dbc', B, 'abcd', temp1, 'adb');
D = etprod('cdb', C, 'ab', temp2, 'acd');
由于 MATLAB 中的默认点积函数(用于 2D 矩阵)要快得多etprod
,因此我想A, B, C, D
以一种能够使用默认函数处理多个 2D 矩阵的方式将其重塑为 2D 数组,而无需手动编写for
循环。我也没有成功。
有什么想法吗?谢谢!