2

因此,为了简单起见,我有这行代码来加载来自两个名为“0”和“1”的类的图像数据集:

train_data = torchvision.datasets.ImageFolder(os.path.join(TRAIN_DATA_DIR), train_transform)

然后我以这种方式准备要与我的模型一起使用的加载器:

train_loader = torch.utils.data.DataLoader(train_data, TRAIN_BATCH_SIZE, shuffle=True)

所以现在每个图像都与一个类相关联,我想要做的是获取每个图像并在这两行代码之间对其应用转换,假设旋转四个度数之一:0、90、180、270 ,并将该信息添加为四个类的附加标签:0、1、2、3。最后,我希望数据集包含旋转后的图像,并将两个值的列表作为它们的标签:图像的类和应用旋转。

我试过了,没有错误,但是如果我尝试打印标签,数据集保持不变:

for idx,label in enumerate(train_data.targets):
    train_data.targets[idx] = [label, 1]

有没有一种很好的方法可以通过直接修改train_data而不需要自定义数据集来做到这一点?

4

1 回答 1

0

有没有一种很好的方法可以通过直接修改 train_data 而不需要自定义数据集来做到这一点?

不,没有。如果你想使用datasets.ImageFolder,你必须接受它有限的灵活性。实际上,ImageFolder它只是 的一个子类DatasetFolder,这几乎就是自定义数据集。您可以在其源代码中看到以下部分__getItem__

if self.transform is not None:
    sample = self.transform(sample)
if self.target_transform is not None:
    target = self.target_transform(target)

这使您想要的成为不可能,因为您预期的变换应该同时修改图像和目标,这在此处独立完成。

所以,首先让你的子类Dataset类似于DatasetFolder,然后简单地实现你自己的变换,它同时接收图像和目标并返回它们的变换值。这只是您可以拥有的转换类的一个示例,然后需要将其组合成一个函数调用:

class RotateTransform(object):
    def __call__(self, image, target):
        # Rotate the image randomly and adjust the target accordingly
        #...

        return image, target

如果这对您的情况来说太麻烦了,那么您拥有的最佳选择是@jchaykow 提到的,即在运行代码之前简单地修改您的文件。

于 2019-10-23T20:14:58.163 回答