如果您想要一个 DataLoader,并且只想为每个样本指定类别标签,那么可以使用 torch.utils.data.Subset 类。
尽管名字叫 Subset,它并不一定要定义数据集的真子集(同一个索引可以重复出现,顺序也可以任意指定)。例如:
import torch
import torchvision
import torchvision.transforms as T
from itertools import cycle
# Build a DataLoader over MNIST that yields samples in a caller-chosen class
# order, by wrapping the dataset in a Subset whose index list is picked to
# match `target_classes`.
# NOTE: no `download=True` is passed, so the MNIST files are expected to be
# present under `root` already.
mnist = torchvision.datasets.MNIST(root='./', train=True, transform=T.ToTensor())

# Desired class label of each consecutive sample; define this list however
# you like.
target_classes = [1, 1, 1, 1, 1, 7, 7, 7, 7, 7, 3, 3, 3, 3, 3]

# For every label present in MNIST, create a cyclic iterator over the dataset
# indices that carry that label. `cycle` makes the iterator wrap around, so a
# class can be requested more times than it has samples.
indices = {
    label: cycle(torch.nonzero(mnist.targets == label).flatten().tolist())
    for label in torch.unique(mnist.targets).tolist()
}

# Walk `target_classes` in order and take the next unused index of each
# requested class; this fixes the sample order of the new dataset.
new_indices = [next(indices[t]) for t in target_classes]

# Wrap MNIST in a Subset driven by `new_indices`. shuffle=False keeps the
# order chosen above.
mnist_modified = torch.utils.data.Subset(mnist, new_indices)
dataloader = torch.utils.data.DataLoader(mnist_modified, batch_size=1, shuffle=False)

for idx, (x, y) in enumerate(dataloader):
    # training loop
    print(f'Batch {idx+1} labels: {y.tolist()}')