0

我正在尝试实现一个自定义的 IterableDataset,在其中我从文件中读取单词,获取它们的唯一 ID,收集它们并将它们批量返回。

import os
import torch
import tqdm
from torch.utils.data import IterableDataset, DataLoader
import vocab  # C++ class bound to python with pybind11

class MyIterableDataset(IterableDataset):
    """Stream vocabulary ids for the words of a whitespace-separated text file.

    The file is divided into ``num_workers`` contiguous ranges (stored in
    ``self.endpoints``); each DataLoader worker starts reading at the
    beginning of its own range and yields the ids of the words it finds.

    NOTE(review): offsets are computed from the file size in bytes but are
    used with a text-mode file handle, which presumably assumes a
    single-byte encoding such as ASCII -- confirm for other inputs.
    """

    def __init__(self, file_path, v, num_workers=4):
        """Record the file, the vocabulary and the per-worker ranges.

        Args:
            file_path: Path of the text file to read.
            v: Vocabulary object providing ``word2id(word)`` and
                ``get_train_words()``.
            num_workers: Number of ranges to split the file into; should
                match the ``num_workers`` given to the DataLoader.

        Raises:
            ValueError: If ``num_workers`` is smaller than 1.
        """
        if num_workers < 1:
            raise ValueError("num_workers must be a positive integer")
        # The zero-argument form binds to this instance; the previous
        # ``super(MyIterableDataset)`` produced an unbound super object and
        # never reached the parent initializer.
        super().__init__()
        self.file_path = file_path
        self.file_size = os.stat(file_path).st_size
        self.v = v  # vocab object, bound from a C++ class with pybind11

        # Spread the remainder over the first ``bonus`` ranges so that the
        # range sizes differ by at most one.
        chunk_size, bonus = divmod(self.file_size, num_workers)
        self.endpoints = []
        start = 0
        for index in range(num_workers):
            end = start + chunk_size + (1 if index < bonus else 0)
            self.endpoints.append((start, end))
            start = end

    def read_word(self, f):
        """Read the next token from the open text file ``f``.

        Returns:
            The next run of non-whitespace characters; ``"\\n"`` when a
            newline is met before any word character; ``''`` at end of
            file when no word is pending.
        """
        chars = []
        while True:
            ch = f.read(1)
            if not ch:
                # End of file: hand back whatever was collected so a final
                # word without trailing whitespace is not lost.
                return "".join(chars)
            if ch.isspace():
                if chars:
                    break
                if ch == '\n':
                    return "\n"
                continue
            chars.append(ch)
        return "".join(chars)

    def parse_file(self, start, words_to_read, id):
        """Yield the ids of in-vocabulary words read from offset ``start``.

        Args:
            start: Offset in the file at which to begin.
            words_to_read: Number of ids after which reading stops. The
                limit is checked after each word has been processed.
            id: Worker index. Workers other than the first skip forward to
                the next whitespace so they do not begin inside a word.

        Yields:
            The value of ``self.v.word2id(word)`` for every word whose id
            is not ``-1``. These are yielded as returned by the vocabulary;
            the DataLoader's default collation is what turns batches of
            numbers into tensors.
        """
        words_read = 0
        # The context manager closes the file even when the consumer stops
        # iterating before the generator is exhausted.
        with open(self.file_path, "r") as f:
            f.seek(start, 0)
            if id > 0:
                while True:
                    ch = f.read(1)
                    if not ch or ch.isspace():
                        break
                    start += 1
                f.seek(start, 0)
            while True:
                word = self.read_word(f)
                if word and word != "\n":
                    wid = self.v.word2id(word)
                    if wid != -1:
                        words_read += 1
                        yield wid
                if words_read >= words_to_read or not word:
                    break

    def __iter__(self):
        """Return the id generator for the current worker.

        When iterated in the main process (no worker info available) the
        whole file is read from the beginning.
        """
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            # Single-process loading: there is no worker to index the
            # endpoints with, so read everything from offset 0.
            return self.parse_file(0, self.v.get_train_words(), 0)
        words_to_read = self.v.get_train_words() // worker_info.num_workers
        start, _ = self.endpoints[worker_info.id]
        return self.parse_file(start, words_to_read, worker_info.id)

在我的数据集上运行 DataLoader 时

num_workers = 7
v = vocab.Vocab("./text8")  # Vocab is a C++ class bound to python with pybind11
ds = MyIterableDataset(file_path=file_path, v=v, num_workers=num_workers)
loader = DataLoader(ds, num_workers=num_workers, batch_size=10)
# Batches of ids arrive from the worker processes as tensors backed by shared
# memory, and each one keeps a handle open for as long as it is referenced.
# Keeping every received batch therefore exhausts the open-file limit, so
# store a process-local copy and let the shared original be released.
wids = [batch.clone() for batch in tqdm.tqdm(loader)]

每当我产生(yield)单词 ID 时,都会收到以下错误:

RuntimeError                              Traceback (most recent call last)
<ipython-input-18-04575fb9c982> in <module>
      2 
      3 t0 = time.time()
----> 4 tokens = [j for _, j in tqdm.tqdm(enumerate(DataLoader(ds, num_workers=num_workers, batch_size=10)))]
      5 print()
      6 print(time.time() - t0)

<ipython-input-18-04575fb9c982> in <listcomp>(.0)
      2 
      3 t0 = time.time()
----> 4 tokens = [j for _, j in tqdm.tqdm(enumerate(DataLoader(ds, num_workers=num_workers, batch_size=10)))]
      5 print()
      6 print(time.time() - t0)

~/miniconda3/envs/word2gm/lib/python3.8/site-packages/tqdm/std.py in __iter__(self)
   1165 
   1166         try:
-> 1167             for obj in iterable:
   1168                 yield obj
   1169                 # Update and possibly print the progressbar.

~/miniconda3/envs/word2gm/lib/python3.8/site-packages/torch/utils/data/dataloader.py in __next__(self)
    433         if self._sampler_iter is None:
    434             self._reset()
--> 435         data = self._next_data()
    436         self._num_yielded += 1
    437         if self._dataset_kind == _DatasetKind.Iterable and \

~/miniconda3/envs/word2gm/lib/python3.8/site-packages/torch/utils/data/dataloader.py in _next_data(self)
   1066 
   1067             assert not self._shutdown and self._tasks_outstanding > 0
-> 1068             idx, data = self._get_data()
   1069             self._tasks_outstanding -= 1
   1070             if self._dataset_kind == _DatasetKind.Iterable:

~/miniconda3/envs/word2gm/lib/python3.8/site-packages/torch/utils/data/dataloader.py in _get_data(self)
   1032         else:
   1033             while True:
-> 1034                 success, data = self._try_get_data()
   1035                 if success:
   1036                     return data

~/miniconda3/envs/word2gm/lib/python3.8/site-packages/torch/utils/data/dataloader.py in _try_get_data(self, timeout)
    897             except OSError as e:
    898                 if e.errno == errno.EMFILE:
--> 899                     raise RuntimeError(
    900                         "Too many open files. Communication with the"
    901                         " workers is no longer possible. Please increase the"

RuntimeError: Too many open files. Communication with the workers is no longer possible. Please increase the limit using `ulimit -n` in the shell or change the sharing strategy by calling `torch.multiprocessing.set_sharing_strategy('file_system')` at the beginning of your code

而如果我改为产生(yield)单词本身(字符串)而不是它的 ID,则一切正常!有人可以帮我理解为什么会发生这种情况吗?

4

0 回答 0