我用 torchvision 下载了 CIFAR-10 数据集，然后运行了下面的代码：
# Inspect (and display) the first image of the first batch.
# Note: the printed `trainloader.dataset.data` shape stays at the raw stored
# resolution (32x32 in the output shown), while the batches the loader yields
# are already resized -- the transform is applied per sample when the loader
# fetches it, not to `dataset.data` itself.
for idx, (img, target) in enumerate(trainloader):
    print(trainloader.dataset.data.shape)
    print(img.shape)  # batched tensor, e.g. torch.Size([1, 3, 224, 224])
    # Index the first sample instead of squeeze(): squeeze() only drops the
    # batch dim when batch_size == 1, and the 3-axis permute below would fail
    # on a 4-D batch. `img` is already a tensor, so no torch.tensor() re-wrap
    # (that copies it and triggers PyTorch's copy-construct warning).
    img = img[0]
    print(img.shape)
    # matplotlib's imshow wants channels last: (C, H, W) -> (H, W, C).
    img = img.permute(1, 2, 0)
    plt.imshow(img)
    plt.show()
    if idx == 0:
        break
但是打印结果很奇怪：
(50000, 32, 32, 3)
torch.Size([1, 3, 224, 224])
torch.Size([3, 224, 224])
Each image in a batch has shape (3, 224, 224), but `dataset.data` has shape (50000, 32, 32, 3), i.e. each stored image is (32, 32, 3).
I want the dataset's shape to be (50000, 224, 224, 3).
def load_data_cifar10(batch_size=128, test=False,
                      root='/mnt/3CE35B99003D727B/input/pytorch/data'):
    """Build a DataLoader over CIFAR-10, downloading the data if needed.

    Every sample is passed through the module-level ``transform`` when the
    loader fetches it; ``loader.dataset.data`` keeps the raw stored images.

    Args:
        batch_size: Number of samples per batch.
        test: If True use the test split, otherwise the training split.
        root: Directory the dataset is read from / downloaded into.

    Returns:
        A ``torch.utils.data.DataLoader`` over the selected split, in fixed
        (unshuffled) order. Previously the loader was built but never
        returned, so callers always received ``None``.
    """
    # Both branches used to be identical except for `train=`; collapse them.
    dset = torchvision.datasets.CIFAR10(root=root, train=not test,
                                        download=True, transform=transform)
    return torch.utils.data.DataLoader(dset, batch_size=batch_size,
                                       shuffle=False)
# Per-sample preprocessing used by load_data_cifar10: resize, then convert to
# a tensor. With an int size, torchvision's Resize matches the smaller edge to
# that value; CIFAR images are square, so batches come out 3x224x224 (as the
# printed output shows). The dataset's stored `data` array is not modified.
transform = transforms.Compose(
    [transforms.Resize(size=224), transforms.ToTensor()]
)