嗨,我想将自己的图像添加到 torchvision 中的 CIFAR10 数据集,我该怎么做?
train_data = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)
train_data.add # or a workaround!
谢谢
嗨,我想将自己的图像添加到 torchvision 中的 CIFAR10 数据集,我该怎么做?
train_data = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)
train_data.add # or a workaround!
谢谢
您可以使用此处CIFAR10
的原始 cifar10 图像创建自定义数据集,或者您仍然可以在新的自定义数据集中使用数据集,然后在方法中添加您的逻辑。
这是一个简单的例子来帮助你:CIFAR10
__getitem__()
class CIFAR10_2(torch.utils.data.Dataset):
def __init__(self, dataset_path='/cifar10', transformations=None, should_download=True):
self.dataset_train = torchvision.datasets.CIFAR10(dataset_path, download=should_download)
self.transformations = transformations
def __getitem__(self, index):
# do as you wish , add your logic here
(img, label) = self.dataset_train[index]
# for transformations for example
if self.transformations is not None:
return self.transformations(img), label
return img, label
def __len__(self):
return len(self.dataset_train)
您可以花哨并为测试、验证等添加逻辑,并做任何您喜欢的事情。