我构建了一个 Siamese Network 对四类书法风格进行分类。Siamese Network 有两个分支,损失函数由三部分组成:两个分支的两个分类损失分别是 CrossEntropyLoss,和一个带有权重的 ContrastiveLoss;问题是第一个分支的分类准确率随着训练epoch的增加而提高,可以达到95%甚至更高。但是,另一个分支几乎没有变化,在50%。问题出在哪里?
class TwinData(data.Dataset):
def __init__(self, imgdataset, trans=None):
self.imgdataset = imgdataset
self.trans = trans
def __getitem__(self, index):
# print(self.imgdataset.imgs[0])
img1_t = self.imgdataset.imgs[index]
same = random.randint(0, 1)
if same:
while True:
img2_t = random.choice(self.imgdataset.imgs)
if img1_t[1] == img2_t[1]:
break
else:
while True:
img2_t = random.choice(self.imgdataset.imgs)
if img1_t[1] != img2_t[1]:
break
img1 = Image.open(img1_t[0])
img2 = Image.open(img2_t[0])
img1 = img1.convert("L")
img2 = img1.convert("L")
if self.trans:
img1 = self.trans(img1)
img2 = self.trans(img2)
return img1, img2, img1_t[1], img2_t[1], torch.from_numpy(np.array(int(img1_t[1] != img2_t[1]), dtype=np.float32))
def __len__(self):
return len(self.imgdataset.imgs)
class SiameseNet(nn.Module):
def __init__(self):
super(SiameseNet, self).__init__()
self.net = BasicNet()
def forward(self, x1, x2):
output1 = self.net(x1)
output2 = self.net(x2)
return output1, output2
def get_feature(self, x):
x = self.net(x)
return x
我在文件中定义了 BasicNet 并没有粘贴到这里。
class ContrastiveLoss(nn.Module):
def __init__(self, margin=1.0):
super(ContrastiveLoss, self).__init__()
self.margin = margin
def forward(self, output1, output2, label):
distance = F.pairwise_distance(output1, output2, keepdim=True)
loss = torch.mean((1 - label) * torch.pow(distance, 2) +
(label) * torch.pow(torch.clamp(self.margin - distance, min=0.0), 2))
return loss
for epoch in range(EPOCH):
print('*' * 30, 'epoch {}'.format(epoch + 1), '*' * 30)
net.train()
running_loss, running_acc_1, running_acc_2 = 0.0, 0.0, 0.0
for i, (img1, img2, label1, label2, target) in enumerate(twin_train_dataloader):
img1, img2, label1, label2, target = img1.cuda(), img2.cuda(), label1.cuda(), label2.cuda(), target.cuda()
optimizer.zero_grad()
outputs_1, outputs_2 = net(img1, img2)
loss_1 = criterion_1(outputs_1 + EPSILON, label1)
loss_2 = criterion_2(outputs_2 + EPSILON, label2)
twin_loss = criterion_3(outputs_1, outputs_2, target)
loss = loss_1 + loss_2 + LAMBDA * twin_loss
running_loss += loss.item() * target.size(0)
_, preds_1 = torch.max(outputs_1, 1)
_, preds_2 = torch.max(outputs_2, 1)
num_correct_1 = (preds_1 == label1).sum()
num_correct_2 = (preds_2 == label2).sum()
running_acc_1 += num_correct_1.item()
running_acc_2 += num_correct_2.item()
loss.backward()
optimizer.step()