4

我需要BatchSampler在 pytorch中使用 a 而不是多次DataLoader调用数据集(远程数据集,每个查询都很昂贵)。我无法理解如何将批处理采样器与任何给定的数据集一起使用。__getitem__

例如

class MyDataset(Dataset):

    def __init__(self, remote_ddf, ):
        self.ddf = remote_ddf

    def __len__(self):
        return len(self.ddf)

    def __getitem__(self, idx):
        return self.ddf[idx] --------> This is as expensive as a batch call

    def get_batch(self, batch_idx):
        return self.ddf[batch_idx]

my_loader = DataLoader(MyDataset(remote_ddf), 
           batch_sampler=BatchSampler(Sampler(), batch_size=3))

我不明白的事情,在网上或torch docs中都没有找到任何示例,是如何使用我的get_batch函数而不是 __getitem__ 函数。
编辑:按照 Szymon Maszke 的回答,这是我尝试过的,但\_\_get_item__每次调用都会获取一个索引,而不是大小列表batch_size

class Dataset(Dataset):

    def __init__(self):
       ...

    def __len__(self):
        ...

    def __getitem__(self, batch_idx):  ------> here I get only one index
        return self.wiki_df.loc[batch_idx]


loader = DataLoader(
                dataset=dataset,
                batch_sampler=BatchSampler(
                    SequentialSampler(dataset), batch_size=self.hparams.batch_size, drop_last=False),
                num_workers=self.hparams.num_data_workers,
            )
4

1 回答 1

5

您不能使用get_batch代替,__getitem__而且我认为这样做没有意义。

torch.utils.data.BatchSamplerSampler()从您的实例中获取索引(在这种情况下3)并返回它,list以便可以在您的MyDataset __getitem__方法中使用这些索引(检查源代码,大多数采样器和与数据相关的实用程序在您需要时很容易遵循)。

我假设您self.ddf支持列表切片(例如self.ddf[[25, 44, 115]],正确返回值并且只使用一个昂贵的调用)。在这种情况下,只需切换get_batch__getitem__就可以了。

class MyDataset(Dataset):

    def __init__(self, remote_ddf, ):
        self.ddf = remote_ddf

    def __len__(self):
        return len(self.ddf)

    def __getitem__(self, batch_idx):
        return self.ddf[batch_idx] -> batch_idx is a list

编辑:您必须指定batch_samplersampler,否则批次将分为单个索引。这应该没问题:

loader = DataLoader(
    dataset=dataset,
    # This line below!
    sampler=BatchSampler(
        SequentialSampler(dataset), batch_size=self.hparams.batch_size, drop_last=False
    ),
    num_workers=self.hparams.num_data_workers,
)
于 2020-04-27T12:13:15.610 回答