我的模型训练涉及对同一图像的多个变体进行编码,然后将生成的表示对图像的所有变体求和。
数据加载器生成形状的张量批次:[batch_size,num_variants,1,height,width]
. 1
对应于图像颜色通道。
如何在 pytorch 中使用 minibatches 训练我的模型?我正在寻找一种通过网络转发所有 batch_size×num_variant 图像并将所有变体组的结果相加的正确方法。
我目前的解决方案涉及展平前两个维度并执行 for 循环来对表示进行求和,但我觉得应该有更好的方法,而且我不确定渐变是否会记住所有内容。