我必须将每个图像的对应图像添加到原始 CIFAR 数据集中,旋转 90 度。这个想法是创建一个 RotationDateset,一个扩展 datasets.VisionDataset 的类,它采用 CIFAR 并执行上述操作。
from __future__ import print_function, division
import skimage.io
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
from torchvision.datasets import ImageFolder
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
from sklearn.model_selection import train_test_split
import copy
import cv2
from torchvision.models.resnet import BasicBlock
from torchvision.models.resnet import ResNet
from PIL import Image
import xml.etree.ElementTree as ET
from torch.utils.model_zoo import load_url as load_state_dict_from_url
from torchvision.models.resnet import model_urls
//org_dataset 是 CIFAR //num_rots 是 4 //transforms 是 transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5,0.5), (0.5, 0.5, 0.5))])
class RotDataset(datasets.VisionDataset):
def __init__(self, org_dataset, transforms, num_rots):
self.samples = org_dataset.data
self.targets = []
self.num_rots = num_rots
self.transforms = transforms
for k in self.samples:
self.targets.append(k)
for i in range(0, self.num_rots):
tr = torchvision.transforms.Compose([torchvision.transforms.RandomRotation(degrees=90*i),
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
# from PIL import Image
p_i = Image.fromarray(k)
te = tr(p_i)
r_im = torch.reshape(te, (k.shape))
r_im = np.array(r_im)
self.targets.append(r_im)
def __len__(self):
return len(self.samples)
def __getitem__(self, index):
imgs = self.targets[index:index + self.num_rots]
labels = list(range(0, self.num_rots))
return imgs, labels
这是我最初导入和转换 CIFAR 的方式:
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
classes = ('plane', 'car', 'bird', 'cat',
'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
这是我创建 CIFAR 增强的方法:
cifar_rot = RotDataset(trainset, trainset.transforms, 4)
rot_train, rot_val= train_test_split(
np.arange(len(cifar_rot.targets)),
test_size=0.2,
shuffle=True,
)
train_sampler_rot = torch.utils.data.SubsetRandomSampler(rot_train)
val_sampler_rot = torch.utils.data.SubsetRandomSampler(rot_val)
dataloaders_rot = {'train': torch.utils.data.DataLoader(cifar_rot, batch_size=8, sampler=train_sampler_rot)
, 'val':torch.utils.data.DataLoader(cifar_rot, batch_size=8, sampler=val_sampler_rot)}
sizes_rot = {'train':len(rot_train)*4,'val':len(rot_val)*4}
和模型训练
model_rot = torchvision.models.resnet34(pretrained=False)
num_ftrs = model_rot.fc.in_features
output_dim_rot = 4 # since are 4 rotations
model_rot.fc = nn.Linear(num_ftrs, output_dim_rot)
model_rot = model_rot.to(device)
criterion = nn.CrossEntropyLoss()
optimizer_conv = optim.SGD(model_rot.parameters(), lr=0.001, momentum=0.9)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_conv, step_size=7, gamma=0.1)
model_rot = train_model(model_rot,
criterion,
optimizer_conv,
exp_lr_scheduler,
dataloaders_rot,
sizes_rot,
num_epochs=10)
torch.save(model_rot.state_dict(),'rotation_resnet34_10_epochs.pt')
//问题是当我启动模型时,pythorch会抛出这个错误:
Epoch 0/9
----------
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-61-977dbbbef6fe> in <module>()
23 dataloaders_rot,
24 sizes_rot,
---> 25 num_epochs=10)
26 #Save the best trained model, for later use
27 torch.save(model_rot.state_dict(),'rotation_resnet34_10_epochs.pt')
6 frames
/usr/local/lib/python3.7/dist-packages/torch/nn/modules/conv.py in _conv_forward(self, input, weight, bias)
394 _pair(0), self.dilation, self.groups)
395 return F.conv2d(input, weight, bias, self.stride,
--> 396 self.padding, self.dilation, self.groups)
397
398 def forward(self, input: Tensor) -> Tensor:
RuntimeError: Given groups=1, weight of size [64, 3, 7, 7], expected input[32, 32, 32, 3] to have 3 channels, but got 32 channels instead
任何人都可以帮助我吗?提前致谢