0

我将原始数据图像保存在单独的 CSV 文件中(文件中的每个图像)。我想使用 PyTorch 对它们进行 CNN 训练。我应该如何加载适合用作 CNN 输入的数据?(另外,它是 1 通道,图像网络的输入默认为 RGB)

4

1 回答 1

0

顾名思义,PyTorch 的 DataLoader 只是一个实用程序类,可帮助您并行加载数据、构建批处理、随机播放等,您需要的是自定义 Dataset 实现。

忽略存储在 CSV 文件中的图像有点奇怪的事实,你只需要这样的东西:

from torch.utils.data import Dataset, DataLoader


class CustomDataset(Dataset):

    def __init__(self, path: Path, ...):
        # do some preliminary checks, e.g. your path exists, files are there...
        assert path.exists()
        ...
        # retrieve your files in some way, e.g. glob
        self.csv_files = list(glob.glob(str(path / "*.csv")))

    def __len__(self) -> int:
        # this lets you know len(dataset) once you instantiate it
        return len(self.csv_files)


    def __getitem__(self, index: int) -> Any:
        # this method is called by the dataloader, each index refers to
        # a CSV file in the list you built in the constructor
        csv = self.csv_files[index]
        # now do whatever you need to do and return some tensors
        image, label = self.load_image(csv)
        return image, label

就是这样,或多或少。然后,您可以创建数据集,将其传递给数据加载器并迭代数据加载器,例如:

dataset = CustomDataset(Path("path/to/csv/files"))
train_loader = DataLoader(dataset, shuffle=True, num_workers=8,...)

for batch in train_loader:
    ...
于 2021-08-07T14:23:16.623 回答