1

假设我有一个 cifar10 的数据加载器,
如果我想从数据加载器中删除一些值并创建一个新的数据加载器,我
应该怎么做?

def load_data_cifar10(batch_size=128,test=False):
    if not test:
        train_dset = torchvision.datasets.CIFAR10(root='/mnt/3CE35B99003D727B/input/pytorch/data', train=True,
                                                download=True, transform=transform)
    else:
        train_dset = torchvision.datasets.CIFAR10(root='/mnt/3CE35B99003D727B/input/pytorch/data', train=False,
                                               download=True, transform=transform)
    train_loader = torch.utils.data.DataLoader(train_dset, batch_size=batch_size, shuffle=True)
    print("LOAD DATA, %d" % (len(train_loader)))
    return train_loader
4

1 回答 1

1

您可以使用Subset数据集。这需要另一个数据集作为输入以及一个索引列表来构建一个新的数据集。假设您想要前 1000 个条目,那么您可以这样做

subset_train_dset = torch.utils.data.Subset(train_dset, range(1000))

您还可以使用ConcatDatasetdataset 或组合构建由多个数据集组成的数据集,ConcatDatasetSubset构建您喜欢的任何内容

frankenstein_dset = torch.utils.data.ConcatDataset((
    torch.utils.data.Subset(dset1, range(1000)),
    torch.utils.data.Subset(dset2, range(100)))

在您的情况下,您需要查看实现细节以确定要保留哪些索引,或者您可以编写一些代码来首先遍历原始数据集并保存您要保留的所有索引,然后Subset使用适当的索引定义.

于 2019-12-19T16:59:53.807 回答