我正在使用批量矩阵乘法编写程序,可能不在通用设置下。我正在考虑以下输入:
# Let's say I have a list of points in R^3, from 3 distinct objects
# (so my data batch has 3 data entry)
# X: (B1+B2+B3) * 3
X = torch.tensor([[1,1,1],[1,1,1],
[2,2,2],[2,2,2],[2,2,2],
[3,3,3],])
# To indicate which object the points are corresponding to,
# I have a list of indices (say, starting from 0):
# idx: (B1+B2+B3)
idx = torch.tensor([0,0,1,1,1,2])
# For each point from the same object, I want to multiply it to a 3x3 matrix, A_i.
# As I have 3 objects here, I have A_0, A_1, A_2.
# A: 3 x 3 x 3
A = torch.tensor([[[1,1,1],[1,1,1],[1,1,1]],
[[2,2,2],[2,2,2],[2,2,2]],
[[3,3,3],[3,3,3],[3,3,3]]])
所需的输出是:
out = X.unsqueeze(1).bmm(A[idx])
out = out.squeeze(1) # just to remove excessive dimension
# out = torch.tensor([[[1,1,1]],[[1,1,1]], # obj0 mult with A_0
[[2,2,2]],[[2,2,2]],[[2,2,2]], # obj1 mult with A_1
[[3,3,3]],]) # obj2 mult with A_2
它实际上在 pytorch 中非常方便,只需一行!
在这里,我想改进这个程序。请注意,我使用A[idx]为每个点复制一个矩阵 A_i,因此我可以在此处使用 torch.bmm() 函数(1 个点 <-> 1 个矩阵)。Afaik,它将需要为A[idx]的中间表示分配内存。一般来说,如果我的数据批次中有 BN 对象,则 A[idx] = (B1+...+BN)*3*3 的大小可能非常大。
因此,我想知道是否可以避免矩阵 A_i 的复制。
我发现了有关 Batch Mat 的大多数先前被问到的问题。多。只假设固定的批量大小。这里问了和我一样的问题,并提供了 tensorflow 中的解决方案。但是,该解决方案是使用 tf.tile() 实现的,它也是复制矩阵。
总而言之,我的问题是关于批量矩阵乘法,同时实现:
- dynamic batch size
- input shape: (B1+...+BN) x 3
- index shape: (B1+...+BN)
- memory efficiency
- probably w/out massive replication of matrix
我在这里使用 pytorch,但我也接受其他实现。如果可以提高内存效率,我也接受在其他结构中表示输入(例如要相乘的矩阵,A)。