解决此问题的一种可能方法是使用批处理采样器并为您的数据加载器实现一个collate_fn
,它将对您的批处理元素执行动态填充。
拿这个基本数据集:
class DS(Dataset):
def __init__(self, files):
super().__init__()
self.len = len(files)
self.files = files
def __getitem__(self, index):
return self.files[index]
def __len__(self):
return self.len
用一些随机数据初始化:
>>> file_len = np.random.randint(0, 100, (16*6))
>>> files = [np.random.rand(s) for s in file_len]
>>> ds = DS(files)
首先定义您的批处理采样器,这本质上是一个可迭代的返回批处理索引,数据加载器将使用这些索引从数据集中检索元素。正如您所解释的,我们可以对长度进行排序并从这种排序中构造不同的批次:
>>> batch_size = 16
>>> batches = np.split(file_len.argsort()[::-1], batch_size)
我们应该有长度彼此接近的元素。
我们可以实现一个collate_fn
函数来组装批处理元素并集成动态填充。这基本上是在数据集和数据加载器之间放置一个额外的用户定义层。目标是找到批次中最长的元素,并用正确数量的0
s 填充所有其他元素:
def collate_fn(batch):
longest = max([len(x) for x in batch])
s = np.stack([np.pad(x, (0, longest - len(x))) for x in batch])
return torch.from_numpy(s)
然后你可以初始化一个数据加载器:
>>> dl = DataLoader(dataset=ds, batch_sampler=batches, collate_fn=collate_fn)
并尝试迭代,如您所见,我们得到了长度递减的批次:
>>> for x in dl:
... print(x.shape)
torch.Size([6, 99])
torch.Size([6, 93])
torch.Size([6, 83])
torch.Size([6, 76])
torch.Size([6, 71])
torch.Size([6, 66])
torch.Size([6, 57])
...
这种方法有一些缺陷,例如,元素的分布总是相同的。这意味着您将始终以相同的外观顺序获得相同的批次。这是因为此方法基于数据集中元素的长度排序,因此批次的创建没有可变性。您可以通过改组批次来减少这种影响(例如,通过包装batches
在 a 中RandomSampler
)。但是,正如我所说,批次的内容将在整个培训过程中保持不变,这可能会导致一些问题。
请注意,batch_sampler
在您的数据加载器中使用的是互斥选项batch_size
,shuffle
和sampler
!