
I have two questions:

Q1: I want to know the best way of feeding training data to the U-Net:

  1. Send one patient at a time, where each volume is 160x3x192x192
  2. Send random slices drawn from k patients (see the sketch after this list)
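To make option 2 concrete, here is a minimal sketch of a slice-level dataset (RandomSliceDataset is a hypothetical name; the per-patient volumes and masks are assumed to be preloaded tensors of shapes 160x3x192x192 and 160x192x192):

from torch.utils.data import Dataset

class RandomSliceDataset(Dataset):
    """Option 2: serve single slices pooled from k patients."""

    def __init__(self, volumes, masks):
        # volumes: list of k tensors, each of shape (160, 3, 192, 192)
        # masks:   list of k tensors, each of shape (160, 192, 192)
        self.volumes, self.masks = volumes, masks
        # Flat index over every (patient, slice) pair
        self.index = [(p, s)
                      for p, vol in enumerate(volumes)
                      for s in range(vol.shape[0])]

    def __len__(self):
        return len(self.index)

    def __getitem__(self, i):
        p, s = self.index[i]
        return {"volume": self.volumes[p][s], "mask": self.masks[p][s]}

With shuffle=True in the DataLoader, every batch is then a random mix of slices across patients:

# loader = DataLoader(RandomSliceDataset(volumes, masks), batch_size=14, shuffle=True)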

Q2: I currently went with the first option, but I am not getting any good results: the Dice score oscillates. For example, the Dice loss starts at 0.99, decreases to about 0.8, then spikes back up, and the pattern repeats. Does anyone know why this happens?
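For reference, plotting the recorded losses makes the repeating pattern easy to see (a small sketch assuming matplotlib; loss_train is the list returned by the code below):

import matplotlib.pyplot as plt

def plot_dice_loss(loss_train, out_path="dice_loss.png"):
    # One value per weight update; the oscillation shows up as a sawtooth
    plt.figure()
    plt.plot(loss_train)
    plt.xlabel("update step")
    plt.ylabel("Dice loss")
    plt.savefig(out_path)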

Code:

import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm

# UNet, DiceLoss, MSdataset, Get_mean_std, normalize, add_channel and
# ToTensor are local modules of this project.

class main:
    def __init__(self, args):
        self.args = args
        self.train_loader = None
        self.in_channel = None
        self.out_channel = None



    def _config_dataloader(self):
        print("Starting configuration of the dataset")
        print("Collecting validation and training set")

        validation_mode = "val/"
        training_mode = "train/"

        # Per-modality mean/std over k training patients, for z-normalisation
        collect = Get_mean_std(self.args.path + training_mode)
        mean, std = collect(self.args.k)

        mean_flair = mean["FLAIR"]
        mean_t1 = mean["T1"]
        std_flair = std["FLAIR"]
        std_t1 = std["T1"]

        train_dataset = MSdataset(self.args.path + training_mode, composed_transforms=[
            normalize(z_norm=True, mean=mean_flair, std=std_flair),
            normalize(z_norm=True, mean=mean_t1, std=std_t1),
            add_channel(depth=self.args.depth),
            ToTensor()])

        validation_dataset = MSdataset(self.args.path + validation_mode, composed_transforms=[
            normalize(z_norm=True, mean=mean_flair, std=std_flair),
            normalize(z_norm=True, mean=mean_t1, std=std_t1),
            add_channel(depth=self.args.depth),
            ToTensor()])

        train_loader = DataLoader(train_dataset,
                                  batch_size=self.args.batch_size,
                                  shuffle=self.args.shuffle)

        validation_loader = DataLoader(validation_dataset,
                                       batch_size=self.args.batch_size - 1,
                                       shuffle=self.args.shuffle)

        print("Data collected. Returning dataloaders for training and validation set")
        return train_loader, validation_loader

    def __call__(self, is_train=False):
        train_loader, validation_loader = self._config_dataloader()
        complete_data = {"train": train_loader, "validation": validation_loader}

        device = torch.device("cpu" if not torch.cuda.is_available() else self.args.device)

        # 2D U-Net over individual slices; each slice carries 3 channels
        unet = UNet(in_channels=3, out_channels=1, init_features=32)
        unet.to(device)

        optimizer = optim.Adam(unet.parameters(), lr=self.args.lr)
        dsc_loss = DiceLoss()

        loss_train = []
        loss_valid = []

        print("Starting training process. Please wait..")
        sub_batch_size = 14  # slices per forward pass through the 2D network
        for current_epoch in tqdm(range(self.args.epoch), total=self.args.epoch):

        for phase in ["train", "validation"]:

            if phase == "train":
                unet.train()
            
            if phase == "validation":
                unet.eval()

            for i, data_set_batch in enumerate(complete_data[phase]):
                data_dict = data_set_batch
                X, mask = data_dict["volume"], data_dict["mask"]
                X, mask = (X.to(device)).float(), mask.to(device)
                    B, D, C, H, W = X.shape

                    # Fold depth into the batch dimension: the 2D U-Net then
                    # sees B*D independent slices of shape (C, H, W)
                    mask = mask.reshape((B * D, H, W))
                    X = X.reshape((B * D, C, H, W))
  
                    loss_depths = 0  # zero out the accumulated depth loss
                    n_sub = 0        # sub-batches seen for this volume
                    # Enable gradients only while actually training
                    with torch.set_grad_enabled(is_train and phase == "train"):

                        # Stride by sub_batch_size so every slice is used once;
                        # a stride of 1 would re-process overlapping windows
                        for sub_batches in tqdm(range(0, X.shape[0], sub_batch_size)):

                            predicted = unet(X[sub_batches: sub_batches + sub_batch_size])
                            loss = dsc_loss(predicted.squeeze(1),
                                            mask[sub_batches: sub_batches + sub_batch_size])

                            if phase == "train":
                                loss_depths = loss_depths + loss
                                n_sub += 1
                            if phase == "validation":
                                # Validation metrics are not implemented yet
                                continue

                    if phase == "train":
                        # Average over sub-batches so the gradient scale does
                        # not depend on the number of slices in the volume
                        loss_depths = loss_depths / max(n_sub, 1)
                        loss_train.append(loss_depths.item())  # store a float, not the graph
                        loss_depths.backward()
                        optimizer.step()
                        optimizer.zero_grad()

        print("Training and validation is done. Exiting program and returning loss")
        return loss_train
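For completeness, a sketch of how the class can be driven (the field names match the self.args accesses above; the default values shown are illustrative only):

import argparse

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--path", default="data/")
    parser.add_argument("--k", type=int, default=10)  # patients used for mean/std
    parser.add_argument("--depth", type=int, default=160)
    parser.add_argument("--batch_size", type=int, default=2)
    parser.add_argument("--shuffle", action="store_true")
    parser.add_argument("--device", default="cuda:0")
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--epoch", type=int, default=100)
    args = parser.parse_args()

    trainer = main(args)
    loss_train = trainer(is_train=True)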

Note that I have not fully implemented the validation part yet; I just want to see how the network learns first. Thanks!
