定义 idx 的是sampler
or batch_sampler
,正如您在此处看到的(开源项目是您的朋友)。在这段代码sampler
(和注释/文档字符串)中,您可以看到和之间的区别batch_sampler
。如果你看这里,你会看到索引是如何选择的:
def __next__(self):
index = self._next_index()
# and _next_index is implemented on the base class (_BaseDataLoaderIter)
def _next_index(self):
return next(self._sampler_iter)
# self._sampler_iter is defined in the __init__ like this:
self._sampler_iter = iter(self._index_sampler)
# and self._index_sampler is a property implemented like this (modified to one-liner for simplicity):
self._index_sampler = self.batch_sampler if self._auto_collation else self.sampler
注意这是_SingleProcessDataLoaderIter
实现;你可以在_MultiProcessingDataLoaderIter
这里找到(ofc,使用哪一个取决于num_workers
值,你可以在这里看到)。回到采样器,假设您的数据集不是_DatasetKind.Iterable
并且您没有提供自定义采样器,这意味着您正在使用(dataloader.py#L212-L215):
if shuffle:
sampler = RandomSampler(dataset)
else:
sampler = SequentialSampler(dataset)
if batch_size is not None and batch_sampler is None:
# auto_collation without custom batch_sampler
batch_sampler = BatchSampler(sampler, batch_size, drop_last)
让我们看一下默认的 BatchSampler 是如何构建批次的:
def __iter__(self):
batch = []
for idx in self.sampler:
batch.append(idx)
if len(batch) == self.batch_size:
yield batch
batch = []
if len(batch) > 0 and not self.drop_last:
yield batch
非常简单:它从采样器获取索引,直到达到所需的 batch_size。
现在的问题是“__getitem__ 的 idx 如何在 PyTorch 的 DataLoader 中工作?” 可以通过查看每个默认采样器的工作方式来回答。
class SequentialSampler(Sampler):
def __init__(self, data_source):
self.data_source = data_source
def __iter__(self):
return iter(range(len(self.data_source)))
def __len__(self):
return len(self.data_source)
def __iter__(self):
n = len(self.data_source)
if self.replacement:
return iter(torch.randint(high=n, size=(self.num_samples,), dtype=torch.int64).tolist())
return iter(torch.randperm(n).tolist())
因此,由于您没有提供任何代码,我们只能假设:
- 您
shuffle=True
在 DataLoader中使用或
- 您正在使用自定义采样器或
- 您的数据集是
_DatasetKind.Iterable