2

我正在 Pytorch 中训练图像分类模型,并使用它们的默认数据加载器来加载我的训练数据。我有一个非常大的训练数据集,所以通常每个班级有几千个样本图像。过去我训练过的模型总共有大约 20 万张图像,没有任何问题。但是我发现当总共有超过一百万张图像时,Pytorch 数据加载器会卡住。

我相信当我打电话时代码正在挂起datasets.ImageFolder(...)。当我 Ctrl-C 时,这始终是输出:

Traceback (most recent call last):                                                                                                 │
  File "main.py", line 412, in <module>                                                                                            │
    main()                                                                                                                         │
  File "main.py", line 122, in main                                                                                                │
    run_training(args.group, args.num_classes)                                                                                     │
  File "main.py", line 203, in run_training                                                                                        │
    train_loader = create_dataloader(traindir, tfm.train_trans, shuffle=True)                                                      │
  File "main.py", line 236, in create_dataloader                                                                                   │
    dataset = datasets.ImageFolder(directory, trans)                                                                               │
  File "/home/username/.local/lib/python3.5/site-packages/torchvision/datasets/folder.py", line 209, in __init__     │
    is_valid_file=is_valid_file)                                                                                                   │
  File "/home/username/.local/lib/python3.5/site-packages/torchvision/datasets/folder.py", line 94, in __init__      │
    samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file)                                                     │
  File "/home/username/.local/lib/python3.5/site-packages/torchvision/datasets/folder.py", line 47, in make_dataset  │
    for root, _, fnames in sorted(os.walk(d)):                                                                                     │
  File "/usr/lib/python3.5/os.py", line 380, in walk                                                                               │
    is_dir = entry.is_dir()                                                                                                        │
Keyboard Interrupt                                                                                                                       

我认为某处可能存在死锁,但是根据 Ctrl-C 的堆栈输出,它看起来不像在等待锁定。所以后来我认为数据加载器很慢,因为我试图加载更多数据。我让它运行了大约 2 天,但没有任何进展,在加载的最后 2 小时内,我检查了 RAM 使用量保持不变。在过去不到几个小时的时间内,我还能够加载包含超过 20 万张图像的训练数据集。我还尝试将我的 GCP 机器升级为拥有 32 个内核、4 个 GPU 和超过 100GB 的 RAM,但似乎在加载了一定数量的内存后,数据加载器就会卡住。

我很困惑数据加载器在遍历目录时如何卡住,我仍然不确定它是卡住还是非常慢。有什么方法可以改变 Pytortch 数据加载器,使其能够处理超过 100 万张图像进行训练?任何调试建议也值得赞赏!

谢谢!

4

1 回答 1

4

这不是问题DataLoader,而是问题torchvision.datasets.ImageFolder以及它是如何工作的(以及为什么你拥有的数据越多,它的工作就越糟糕)。

它挂在这一行,如您的错误所示:

for root, _, fnames in sorted(os.walk(d)): 

来源可以在这里找到。

潜在的问题是它将每个path和对应的都保存label在 Giant 中list,请参见下面的代码(为简洁起见,删除了一些内容):

def make_dataset(dir, class_to_idx, extensions=None, is_valid_file=None):
    images = []
    dir = os.path.expanduser(dir)
    # Iterate over all subfolders which were found previously
    for target in sorted(class_to_idx.keys()):
        d = os.path.join(dir, target) # Create path to this subfolder
        # Assuming it is directory (which usually is the case)
        for root, _, fnames in sorted(os.walk(d, followlinks=True)):
            # Iterate over ALL files in this subdirectory
            for fname in sorted(fnames):
                path = os.path.join(root, fname)
                # Assuming it is correctly recognized as image file
                item = (path, class_to_idx[target])
                # Add to path with all images
                images.append(item)

    return images

显然,图像将包含 100 万个字符串(也很长),并且对应int于肯定很多并且取决于 RAM 和 CPU 的类。

不过,您可以创建自己的数据集(前提是您事先更改了图像的名称)这样dataset.

设置数据结构

您的文件夹结构应如下所示:

root
    class1
    class2
    class3
    ...

使用您拥有/需要的课程数量。

现在每个都class应该有以下数据:

class1
    0.png
    1.png
    2.png
    ...

鉴于您可以继续创建数据集。

创建数据集

以下torch.utils.data.Dataset用于PIL打开图像,您可以通过其他方式进行操作:

import os
import pathlib

import torch
from PIL import Image


class ImageDataset(torch.utils.data.Dataset):
    def __init__(self, root: str, folder: str, klass: int, extension: str = "png"):
        self._data = pathlib.Path(root) / folder
        self.klass = klass
        self.extension = extension
        # Only calculate once how many files are in this folder
        # Could be passed as argument if you precalculate it somehow
        # e.g. ls | wc -l on Linux
        self._length = sum(1 for entry in os.listdir(self._data))

    def __len__(self):
        # No need to recalculate this value every time
        return self._length

    def __getitem__(self, index):
        # images always follow [0, n-1], so you access them directly
        return Image.open(self._data / "{}.{}".format(str(index), self.extension))

现在您可以轻松地创建数据集(假设文件夹结构如上所示:

root = "/path/to/root/with/images"
dataset = (
    ImageDataset(root, "class0", 0)
    + ImageDataset(root, "class1", 1)
    + ImageDataset(root, "class2", 2)
)

您可以根据需要添加任意数量datasets的指定类,循环或其他方式。

最后,torch.utils.data.DataLoader照常使用,例如:

dataloader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True)
于 2020-02-11T19:27:26.367 回答