0

我的模型训练涉及对同一图像的多个变体进行编码,然后将生成的表示对图像的所有变体求和。

数据加载器生成形状的张量批次:[batch_size,num_variants,1,height,width]. 1对应于图像颜色通道。

如何在 pytorch 中使用 minibatches 训练我的模型?我正在寻找一种通过网络转发所有 batch_size×num_variant 图像并将所有变体组的结果相加的正确方法。

我目前的解决方案涉及展平前两个维度并执行 for 循环来对表示进行求和,但我觉得应该有更好的方法,而且我不确定渐变是否会记住所有内容。

4

1 回答 1

1

不确定我是否正确理解了你,但我想这就是你想要的(比如批量图像张量被调用image):

Nb, Nv, inC, inH, inW = image.shape

# treat each variant as if it's an ordinary image in the batch
image = image.reshape(Nb*Nv, inC, inH, inW)

output = model(image)
_, outC, outH, outW = output.shape[1]

# reshapes the output such that dim==1 indicates variants
output = output.reshape(Nb, Nv, outC, outH, outW)

# summing over the variants and lose the dimension of summation, [Nb, outC, outH, outW]
output = output.sum(dim=1, keepdim=False)

如果输入和输出通道/大小不同,我使用了inC, outC,inH等。

于 2020-11-15T03:36:56.360 回答