我正在实现一个具有多标签预测的分类器。目前只有5个班。以下是代码片段。到目前为止,我没有任何问题。
但我想进一步扩展它,共享相关任务之间的梯度(由编码任务之间相关性的矩阵W约束)。例如:有五个不同的类a、b、c、d、e。我计算了 a 错误分类的梯度并更新了参数。从矩阵W我发现任务e也是相关的。我还想用计算的梯度更新任务e的参数,同时用W编码的速率预测a。这里W是预先计算和固定的。
我不知道如何在 PyTorch 中做到这一点。如果有人可以帮助我,我将不胜感激。
class Classifier(nn.Module):
def __init__(self, image_size=128, conv_dim=64, c_dim=5, repeat_num=6):
super(Classifier, self).__init__()
layers = []
layers.append(nn.Conv2d(3, conv_dim, kernel_size=4, stride=2, padding=1))
layers.append(nn.LeakyReLU(0.01))
curr_dim = conv_dim
for i in range(1, repeat_num):
layers.append(nn.Conv2d(curr_dim, curr_dim*2, kernel_size=4, stride=2, padding=1))
layers.append(nn.LeakyReLU(0.01))
curr_dim = curr_dim * 2
kernel_size = int(image_size / np.power(2, repeat_num))
self.main = nn.Sequential(*layers)
self.conv = nn.Conv2d(curr_dim, 5, kernel_size=kernel_size, bias=False)
def forward(self, x):
h = self.main(x)
out_cls = self.conv2(h)
return out_cls.view(out_cls.size(0), out_cls.size(1))