1

当我试图弄清楚torchvision.datasets.cifar.CIFAR10里面是什么时,我做了一些简单的代码

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                    download=True, transform=transform)
print(trainset[1])
print(trainset[:10])
print(type(trainset))

但是,我在尝试时遇到了一些错误

print(trainset[:10])

错误信息是

TypeError: Cannot handle this data type

我想知道为什么我可以使用trainset[1],但不能trainset[:10]

4

2 回答 2

2

CIFAR10 不支持切片,这就是您收到该错误的原因。如果您想要前 10 个,则必须这样做:

print([trainset[i] for i in range(10)])

更多信息

您可以索引 CIFAR10 类的实例的主要原因是该类实现了__getitem__()功能。

所以,当你打电话时,trainset[i]你实际上是在打电话trainset.__getitem__(i)

现在,在 python3 中,切片表达式也通过切片表达式作为切片对象__getitem__()传递给的位置进行处理。__getitem__()

所以,trainset[2:10]等价于trainset.__getitem__(slice(2, 10))

而且由于被传递给的两种不同类型的对象__getitem__ 预计会做完全不同的事情,因此您必须明确地处理它们。

不幸的是,它不是,从__getitem__CIFAR10 类的方法实现中可以看出:

def __getitem__(self, index):
    if self.train:
        img, target = self.train_data[index], self.train_labels[index]
    else:
        img, target = self.test_data[index], self.test_labels[index]

    # doing this so that it is consistent with all other datasets
    # to return a PIL Image
    img = Image.fromarray(img)

    if self.transform is not None:
        img = self.transform(img)

    if self.target_transform is not None:
        target = self.target_transform(target)

    return img, target
于 2017-07-20T23:51:51.803 回答
0

除了https://stackoverflow.com/a/45226879/7924573 entrophys 答案,我建议使用torch.utils.data.dataset.random_split例如:方式:

train_size = int(0.8*len(dataset))
test_size = len(dataset) - train_size
lengths = [train_size, test_size]
train_dataset, valid_dataset = torch.utils.data.dataset.random_split(dataset, lengths)
trainloader = DataLoader(train_data, 
  batch_size=args.train_batch,  
  shuffle=True, 
  num_workers=args.nThreads, 
  pin_memory=True)
validloader = DataLoader(valid_data, 
  batch_size=args.train_batch,  
  shuffle=True, 
  num_workers=args.nThreads, 
  pin_memory=True)

来源:https ://yimjiyoung.github.io/2020/02/13/How-to-split-dataset-into-train-and-validation-set-in-pytorch/

于 2020-12-14T16:44:11.790 回答