我正在尝试在张量之间进行详尽的连接。因此,例如,我有张量:
a = torch.randn(3, 512)
我想连接像 concat(t1,t1),concat(t1,t2), concat(t1,t3), concat(t2,t1), concat(t2,t2)....
作为一个天真的解决方案,我使用了for
循环:
ans = []
result = []
split = torch.split(a, [1, 1, 1], dim=0)
for i in range(len(split)):
ans.append(split[i])
for t1 in ans:
for t2 in ans:
result.append(torch.cat((t1,t2), dim=1))
问题是每个时代都需要很长时间,而且代码很慢。我尝试了在PyTorch上发布的解决方案:How to implement attention for graph attention layer但这会产生内存错误。
t1 = a.repeat(1, a.shape[0]).view(a.shape[0] * a.shape[0], -1)
t2 = a.repeat(a.shape[0], 1)
result.append(torch.cat((t1, t2), dim=1))
我确信有一种更快的方法,但我无法弄清楚。