2

我必须将每个图像的对应图像添加到原始 CIFAR 数据集中,旋转 90 度。这个想法是创建一个 RotationDateset,一个扩展 datasets.VisionDataset 的类,它采用 CIFAR 并执行上述操作。

from __future__ import print_function, division
import skimage.io

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
from torchvision.datasets import ImageFolder
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
from sklearn.model_selection import train_test_split
import copy
import cv2
from torchvision.models.resnet import BasicBlock
from torchvision.models.resnet import ResNet
from PIL import Image
import xml.etree.ElementTree as ET
from torch.utils.model_zoo import load_url as load_state_dict_from_url
from torchvision.models.resnet import model_urls

//org_dataset 是 CIFAR //num_rots 是 4 //transforms 是 transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5,0.5), (0.5, 0.5, 0.5))])

class RotDataset(datasets.VisionDataset):
    def __init__(self, org_dataset, transforms, num_rots):
        
        self.samples = org_dataset.data
        self.targets = []
        self.num_rots = num_rots
        self.transforms = transforms

        for k in self.samples:
          self.targets.append(k)

          for i in range(0, self.num_rots):
            tr = torchvision.transforms.Compose([torchvision.transforms.RandomRotation(degrees=90*i),
                                                 torchvision.transforms.ToTensor(),
                                                 torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
            # from PIL import Image
            p_i = Image.fromarray(k)
            te = tr(p_i)
            r_im = torch.reshape(te, (k.shape))
            r_im = np.array(r_im)
            self.targets.append(r_im)
            
    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, index):
      imgs = self.targets[index:index + self.num_rots]
      labels = list(range(0, self.num_rots))

      return imgs, labels

这是我最初导入和转换 CIFAR 的方式:

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)


testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

这是我创建 CIFAR 增强的方法:

cifar_rot = RotDataset(trainset, trainset.transforms, 4)

rot_train, rot_val= train_test_split(
np.arange(len(cifar_rot.targets)),
test_size=0.2,
shuffle=True,
)

train_sampler_rot = torch.utils.data.SubsetRandomSampler(rot_train)
val_sampler_rot = torch.utils.data.SubsetRandomSampler(rot_val)

dataloaders_rot = {'train': torch.utils.data.DataLoader(cifar_rot, batch_size=8, sampler=train_sampler_rot)
               , 'val':torch.utils.data.DataLoader(cifar_rot, batch_size=8, sampler=val_sampler_rot)}

sizes_rot = {'train':len(rot_train)*4,'val':len(rot_val)*4}

和模型训练

model_rot = torchvision.models.resnet34(pretrained=False) 

num_ftrs = model_rot.fc.in_features
output_dim_rot = 4 # since are 4 rotations

model_rot.fc = nn.Linear(num_ftrs, output_dim_rot)

model_rot = model_rot.to(device)
criterion = nn.CrossEntropyLoss()

optimizer_conv = optim.SGD(model_rot.parameters(), lr=0.001, momentum=0.9)

exp_lr_scheduler = lr_scheduler.StepLR(optimizer_conv, step_size=7, gamma=0.1)
model_rot = train_model(model_rot,
                        criterion,
                        optimizer_conv,
                        exp_lr_scheduler,
                        dataloaders_rot,
                        sizes_rot,
                        num_epochs=10)

torch.save(model_rot.state_dict(),'rotation_resnet34_10_epochs.pt')

//问题是当我启动模型时,pythorch会抛出这个错误:

Epoch 0/9
----------
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-61-977dbbbef6fe> in <module>()
     23                         dataloaders_rot,
     24                         sizes_rot,
---> 25                         num_epochs=10)
     26 #Save the best trained model, for later use
     27 torch.save(model_rot.state_dict(),'rotation_resnet34_10_epochs.pt')

6 frames
/usr/local/lib/python3.7/dist-packages/torch/nn/modules/conv.py in _conv_forward(self, input, weight, bias)
    394                             _pair(0), self.dilation, self.groups)
    395         return F.conv2d(input, weight, bias, self.stride,
--> 396                         self.padding, self.dilation, self.groups)
    397 
    398     def forward(self, input: Tensor) -> Tensor:

RuntimeError: Given groups=1, weight of size [64, 3, 7, 7], expected input[32, 32, 32, 3] to have 3 channels, but got 32 channels instead

任何人都可以帮助我吗?提前致谢

4

1 回答 1

0

问题来自您的依赖org_dataset.data,这是一个形状的 numpy 数组(N, 32, 32, 3)(您希望它在哪里(N, 3, 32, 32)

因此,使用 line self.targets.append(k),您在目标列表中放置了不正确的形状。然后,张量te具有正确的形状(感谢ToTensor),但是您将其重新整形为错误的形状

我还想指出,随机变换如RandomRotation通常应用在__getitem__方法中,而不是在__init__. 由于在这些转换中会生成随机数,因此您希望每个时期都生成新样本,以便拥有几乎无限的数据集和样本。我实际上不确定您是否了解 RandomRotation 的作用:它通过随机旋转来旋转输入张量,您只需指定可能角度的范围。因此,完全有可能应用参数 180 (i=2) 的“旋转”将产生几乎不变的张量。我看到你试图预测i之后的价值,它很可能不起作用。您可能想改用torch.rot90

除此之外,既然你已经申请了ToTensorand Normalizein RotDataset,你当然不需要他们 in CIFAR10

最后评论:我真的不明白您为什么要__getitem返回张量(和标签)列表。我将在下面的代码中保持这种方式,但看起来它最终会破坏某些东西。

所以,这里是你将如何更正你的代码:

class RotDataset(datasets.VisionDataset):
    def __init__(self, org_dataset, transforms, num_rots):
    
        # Let's buffer the underlying dataset, we will sample   
        # from it on the fly
        self.dataset = org_dataset
        self.num_rots = num_rots
        # You did not use this attribute previously, probably a mistake
        # It will now be applied in the __getitem__
        self.transforms = transforms
        
    def __len__(self):
        # Typical front dataset : size is the same as the 
        # underlying dataset size
        return len(self.dataset)

    def __getitem__(self, index):
        # sampling from CIFAR10
        sample = self.dataset[index]
        # Because you want to return a list
        imgs = []
        for i in range(0, self.num_rots):
            # Creating the corresponding rotation
            rotation = torchvision.transforms.RandomRotation(degrees=90*i)
            # Applying rotation, followed by other transforms (toTensor, Normalize...)
            transformed = self.transform(rotation(sample))
            imgs.append(transformed)

        # Cleaner way to generate your range : 
        labels = np.arange(self.num_rots)

        return imgs, labels

# transform=None, since we will apply them in RotDataset
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=None)
# The transforms to call in RotDataset
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
cifar_rot = RotDataset(trainset, transform, 4)

# using torch's random split to remove dependency on sklearn
from torch.utils.data import random_split
test_size = 0.2*len(cifar_rot)
rot_train, rot_val= random_split(cifar_rot, [len(cifar_rot)-test_size, test_size])
于 2021-05-04T09:19:34.707 回答