因此,为了简单起见,我有这行代码来加载来自两个名为“0”和“1”的类的图像数据集:
train_data = torchvision.datasets.ImageFolder(os.path.join(TRAIN_DATA_DIR), train_transform)
然后我以这种方式准备要与我的模型一起使用的加载器:
train_loader = torch.utils.data.DataLoader(train_data, TRAIN_BATCH_SIZE, shuffle=True)
所以现在每个图像都与一个类相关联,我想要做的是获取每个图像并在这两行代码之间对其应用转换,假设旋转四个度数之一:0、90、180、270 ,并将该信息添加为四个类的附加标签:0、1、2、3。最后,我希望数据集包含旋转后的图像,并将两个值的列表作为它们的标签:图像的类和应用旋转。
我试过了,没有错误,但是如果我尝试打印标签,数据集保持不变:
for idx,label in enumerate(train_data.targets):
train_data.targets[idx] = [label, 1]
有没有一种很好的方法可以通过直接修改train_data
而不需要自定义数据集来做到这一点?