
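I am using pretrained models from PyTorch (ResNet-18, ResNet-34, ResNet-50) to classify images. During training, a strange periodic pattern appears in the loss curves, as shown in the figure below. Has anyone run into a similar problem? To deal with overfitting, I use data augmentation in the preprocessing. We get this kind of plot when using SGD as the optimizer with the following parameters: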

  • Criterion: NLLLoss()
  • Learning rate: 0.0001
  • Epochs: 40
  • Print (and validate) every 40 iterations
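Since `NLLLoss` expects log-probabilities rather than raw logits, this setup presumably relies on the ResNet head ending in a `LogSoftmax` layer (the post does not show how the final layer was replaced). A minimal sketch, assuming random logits and a hypothetical 5-class problem, that checks `NLLLoss` on `log_softmax` outputs gives the same value as `CrossEntropyLoss` on the raw logits:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

num_classes = 5                                # hypothetical number of classes
logits = torch.randn(8, num_classes)           # raw scores, as a plain nn.Linear head would output
labels = torch.randint(0, num_classes, (8,))   # integer class indices

# NLLLoss expects log-probabilities, so apply log_softmax over the class dimension first
nll = nn.NLLLoss()(F.log_softmax(logits, dim=1), labels)

# CrossEntropyLoss applies log_softmax internally and takes the raw logits
ce = nn.CrossEntropyLoss()(logits, labels)

print(nll.item(), ce.item())
```

If the replaced `fc` layer is a bare `nn.Linear`, feeding its output straight into `NLLLoss` would compute a different (and unbounded) quantity, so it is worth confirming which of the two setups is in use.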

[Figure: SGD training vs. validation loss]

We also tried Adam and AdaBound as optimizers, but observed the same periodicity.

Thanks in advance for your answers!

Here is the code:

```python
import timeit

import torch


def train_classifier():
    # `model`, `optimizer`, `criterion`, `train_loader`, `val_loader` and
    # `validation` are defined elsewhere in the script (used here as globals).
    start = timeit.default_timer()
    epochs = 40
    steps = 0          # global batch counter, never reset between epochs
    print_every = 40

    model.to('cuda')
    epo = []
    train = []
    valid = []
    acc_valid = []
    for e in range(epochs):
        print('Currently running epoch', e, ':')
        model.train()

        # NOTE: reset at the start of every epoch while `steps` keeps counting
        # globally, so the first report of an epoch can sum fewer than
        # `print_every` batches; that sum is still divided by `print_every`.
        running_loss = 0

        for images, labels in train_loader:
            steps += 1

            images, labels = images.to('cuda'), labels.to('cuda')

            optimizer.zero_grad()

            output = model(images)  # call the module instead of model.forward()
            loss = criterion(output, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

            if steps % print_every == 0:
                model.eval()

                # Turn off gradients for validation, saves memory and computations
                with torch.no_grad():
                    validation_loss, accuracy = validation(model, val_loader, criterion)

                print("Epoch: {}/{}.. ".format(e + 1, epochs),
                      "Training Loss: {:.3f}.. ".format(running_loss / print_every),
                      "Validation Loss: {:.3f}.. ".format(validation_loss / len(val_loader)),
                      "Validation Accuracy: {:.3f}".format(accuracy / len(val_loader)))
                stop = timeit.default_timer()
                print('Time: ', stop - start)
                acc_valid.append(accuracy / len(val_loader))
                train.append(running_loss / print_every)
                valid.append(validation_loss / len(val_loader))
                epo.append(e + 1)
                running_loss = 0
                model.train()
    return train, epo, valid, acc_valid
```
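The `validation` helper called in the loop is not included in the post. Judging from the call site, which divides both returned values by `len(val_loader)`, it presumably returns sums of per-batch values. A minimal hypothetical sketch under that assumption (the body, and the toy model and data in the usage part, are illustrative and not the author's code):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def validation(model, loader, criterion):
    """Hypothetical helper: returns the SUM of per-batch mean losses and the
    SUM of per-batch accuracies, so the caller divides both by len(loader)."""
    device = next(model.parameters()).device  # follow the model's device (e.g. 'cuda')
    loss_sum = 0.0
    acc_sum = 0.0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        output = model(images)                 # log-probabilities if the head ends in LogSoftmax
        loss_sum += criterion(output, labels).item()
        preds = output.argmax(dim=1)           # argmax is the same on logits or log-probs
        # per-batch mean, so a smaller final batch counts as much as a full one
        acc_sum += (preds == labels).float().mean().item()
    return loss_sum, acc_sum


# Usage on a toy CPU model with a hypothetical 3-class problem
torch.manual_seed(0)
model = torch.nn.Sequential(torch.nn.Linear(4, 3), torch.nn.LogSoftmax(dim=1))
loader = DataLoader(TensorDataset(torch.randn(10, 4), torch.randint(0, 3, (10,))), batch_size=5)
with torch.no_grad():
    loss_sum, acc_sum = validation(model, loader, torch.nn.NLLLoss())
print(loss_sum / len(loader), acc_sum / len(loader))
```

Inferring the device from the model's parameters keeps the three-argument call from the training loop working unchanged whether the model lives on the CPU or on `'cuda'`.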