
I downloaded the CIFAR-10 dataset with torchvision and iterate over it like this:

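import matplotlib.pyplot as plt

# trainloader is assumed to be built with batch_size=1, as the printed shapes suggest
for idx, (img, target) in enumerate(trainloader):
    print(trainloader.dataset.data.shape)  # raw stored images: (N, 32, 32, 3) uint8
    print(img.shape)                       # transformed batch: (1, 3, 224, 224)
    img = img.squeeze(0)                   # drop the batch dimension -> (3, 224, 224)
    print(img.shape)
    img = img.permute(1, 2, 0)             # CHW -> HWC, the layout plt.imshow expects
    plt.imshow(img)
    plt.show()
    if idx == 0:
        break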

But the printed output looks strange:

(50000, 32, 32, 3)
torch.Size([1, 3, 224, 224])
torch.Size([3, 224, 224])
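The leading 1 in the second line and the shape changes after it can be traced with a small sketch. It uses NumPy arrays as a stand-in for torch tensors (an assumption made so it runs without torch). DataLoader's default collation stacks per-sample tensors along a new leading batch axis. `squeeze(0)` removes that axis, and `transpose(1, 2, 0)` performs the same reordering as `tensor.permute(1, 2, 0)`.

```python
import numpy as np

# One per-sample array shaped like ToTensor output (C, H, W); batch_size=1
samples = [np.zeros((3, 224, 224), dtype=np.float32)]

# Default collation stacks samples along a new leading batch axis
batch = np.stack(samples)          # (1, 3, 224, 224)
img = batch.squeeze(0)             # drop the batch axis -> (3, 224, 224)
img_hwc = img.transpose(1, 2, 0)   # same reordering as permute(1, 2, 0) -> (224, 224, 3)

print(batch.shape, img.shape, img_hwc.shape)
# (1, 3, 224, 224) (3, 224, 224) (224, 224, 3)
```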

Each image coming out of the loader has shape (3, 224, 224), but dataset.data stores each image as (32, 32, 3).
I want the dataset's data itself to have shape (50000, 224, 224, 3).
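The mismatch comes from the storage pattern. torchvision's CIFAR10 keeps the raw images in `dataset.data` as a uint8 NumPy array of shape (N, 32, 32, 3). It applies `transform` only inside `__getitem__`, one sample at a time, so `Resize(224)` never touches `dataset.data`. The sketch below mimics that pattern. `LazyImageDataset` and `upscale_to_chw` are hypothetical names, and the nearest-neighbour resize is a simplified stand-in for `transforms.Resize` plus `ToTensor`, not torchvision's actual implementation.

```python
import numpy as np


class LazyImageDataset:
    """Hypothetical stand-in for the CIFAR10 storage pattern: raw images stay
    in self.data, and the transform runs only when an item is read."""

    def __init__(self, data, transform=None):
        self.data = data            # (N, H, W, C) uint8, never modified
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        img = self.data[index]
        if self.transform is not None:
            img = self.transform(img)   # applied per sample, on the fly
        return img


def upscale_to_chw(img, size=224):
    # Hypothetical transform: nearest-neighbour resize HWC -> (C, size, size) float in [0, 1]
    h, w, _ = img.shape
    rows = np.arange(size) * h // size
    cols = np.arange(size) * w // size
    resized = img[rows[:, None], cols[None, :]]          # (size, size, C)
    return resized.transpose(2, 0, 1).astype(np.float32) / 255.0


data = np.zeros((5, 32, 32, 3), dtype=np.uint8)
ds = LazyImageDataset(data, transform=upscale_to_chw)
print(ds.data.shape)   # (5, 32, 32, 3)
print(ds[0].shape)     # (3, 224, 224)
```

Reading an item returns the resized sample, but the stored array keeps its original 32×32 shape. That matches the first two printed lines above.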


import torch
import torchvision
import torchvision.transforms as transforms


def load_data_cifar10(batch_size=128, test=False):
    # train=True selects the 50000-image training split, train=False the 10000-image test split
    dset = torchvision.datasets.CIFAR10(root='/mnt/3CE35B99003D727B/input/pytorch/data', train=not test,
                                        download=True, transform=transform)
    train_loader = torch.utils.data.DataLoader(dset, batch_size=batch_size, shuffle=False)
    return train_loader  # hand the loader back to the caller


transform = transforms.Compose([
    transforms.Resize(224),  # applied per sample when it is loaded; dataset.data is left unchanged
    transforms.ToTensor(),   # PIL HWC uint8 -> CHW float tensor in [0, 1]
])
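Before storing the whole dataset at 224×224, it is worth checking the memory cost. 50000 × 224 × 224 × 3 = 7,526,400,000 bytes, about 7.5 GB as uint8. As float32 it is four times that, about 30.1 GB. Keeping the transform lazy is usually cheaper. If a materialized array is still wanted, the sketch below fills one preallocated (N, 224, 224, 3) uint8 array. `materialize_hwc` and `full_size_bytes` are hypothetical helpers. It assumes each `dataset[i]` yields a (3, 224, 224) float array in [0, 1], like ToTensor output converted to NumPy. A plain list stands in for the torchvision dataset here.

```python
import numpy as np


def full_size_bytes(n, size=224, channels=3, itemsize=1):
    # Bytes needed to hold n images of size x size x channels
    return n * size * size * channels * itemsize


def materialize_hwc(dataset, size=224):
    """Hypothetical helper: copy every transformed sample into one
    preallocated (N, size, size, 3) uint8 array."""
    n = len(dataset)
    out = np.empty((n, size, size, 3), dtype=np.uint8)
    for i in range(n):
        chw = np.asarray(dataset[i])
        # CHW float in [0, 1] -> HWC uint8 in [0, 255]
        out[i] = np.clip(np.rint(chw.transpose(1, 2, 0) * 255.0), 0, 255).astype(np.uint8)
    return out


print(full_size_bytes(50000))               # 7526400000 (uint8)
print(full_size_bytes(50000, itemsize=4))   # 30105600000 (float32)

# Tiny stand-in dataset: 4 samples filled with 1.0
toy = [np.ones((3, 224, 224), dtype=np.float32) for _ in range(4)]
arr = materialize_hwc(toy)
print(arr.shape)                            # (4, 224, 224, 3)
```

For the full training set, a disk-backed `np.memmap` could take the place of `np.empty` so the 7.5 GB array does not have to fit in RAM.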