
I have built a Dataset, where I'm doing various checks on the images being loaded. I then pass this DataSet to a DataLoader.

In my DataSet class, I return the sample as None if a picture fails my checks, and I have a custom collate_fn which removes all the Nones from the retrieved batch and returns the remaining valid samples.

However, at this point the returned batch can be of varying size. Is there a way to tell the collate_fn to keep sourcing data until the batch size reaches a certain length?

import torch
from torch.utils.data import DataLoader

class DataSet(torch.utils.data.Dataset):
    def __init__(self, csv_file, image_dir):
        # initialise dataset:
        # load the csv file and image directory
        self.csv_file = csv_file
        self.image_dir = image_dir

    def __getitem__(self, idx):
        # load one sample:
        # if the image is too dark, return None,
        # else return one image and its matching label
        ...

def my_collate(batch):
    # batch size 4: [{tensor image, tensor label}, {}, {}, {}]
    # could come back as e.g. G = [None, {}, {}, {}]
    batch = list(filter(lambda x: x is not None, batch))
    # this gets rid of the Nones; the example above becomes G = [{}, {}, {}]
    # I want len(G) == 4 -- so how do I sample another dataset entry?
    return torch.utils.data.dataloader.default_collate(batch)

dataset = DataSet(csv_file='../', image_dir='../../')

dataloader = DataLoader(dataset, batch_size=4, shuffle=True,
                        num_workers=1, collate_fn=my_collate)

5 Answers


There are two hacks you can use to work around the problem; pick one approach:

Quick option: pad the batch with its own samples

def my_collate(batch):
    len_batch = len(batch) # original batch length
    batch = list(filter (lambda x:x is not None, batch)) # filter out all the Nones
    if len_batch > len(batch): # if samples are missing, reuse existing members; fails if every sample in a batch is rejected
        diff = len_batch - len(batch)
        batch = batch + batch[:diff]  # assumes diff <= len(batch)
    return torch.utils.data.dataloader.default_collate(batch)

Better option: otherwise, simply load another sample from the dataset at random

import numpy as np  # note: `dataset` must also be visible in this scope

def my_collate(batch):
    len_batch = len(batch)  # original batch length
    batch = list(filter(lambda x: x is not None, batch))  # filter out all the Nones
    if len_batch > len(batch):
        # source the missing samples from the original dataset at random
        diff = len_batch - len(batch)
        for i in range(diff):
            batch.append(dataset[np.random.randint(0, len(dataset))])
    return torch.utils.data.dataloader.default_collate(batch)
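One caveat with this option is that `dataset` is read as a global inside the collate function. A way to avoid the global is to close over the dataset. The sketch below is torch-free for illustration: `make_collate` is a hypothetical name, and the `collate=list` parameter stands in for `torch.utils.data.dataloader.default_collate`.

```python
import random

def make_collate(dataset, collate=list):
    """Build a collate_fn that resamples replacements for None entries.

    `dataset` is anything indexable; assumes it contains at least one
    valid (non-None) sample, otherwise the loop below never terminates.
    """
    def my_collate(batch):
        len_batch = len(batch)                       # original batch length
        batch = [s for s in batch if s is not None]  # drop the Nones
        while len(batch) < len_batch:
            candidate = dataset[random.randrange(len(dataset))]
            if candidate is not None:                # a replacement may also be corrupted
                batch.append(candidate)
        return collate(batch)
    return my_collate

data = [10, None, 20, 30]
collate_fn = make_collate(data)
filled = collate_fn([10, None, 20, None])
print(len(filled))  # → 4
```

Because the closure carries the dataset, the resulting `collate_fn` takes a single `batch` argument, which is the signature the DataLoader expects.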
answered 2019-09-11
4

This worked for me, because sometimes even those randomly drawn replacements are None as well:

def my_collate(batch):
    len_batch = len(batch)
    batch = list(filter(lambda x: x is not None, batch))

    if len_batch > len(batch):                
        db_len = len(dataset)
        diff = len_batch - len(batch)
        while diff != 0:
            a = dataset[np.random.randint(0, db_len)]
            if a is None:                
                continue
            batch.append(a)
            diff -= 1

    return torch.utils.data.dataloader.default_collate(batch)
answered 2021-05-18

For anyone looking to reject training examples on the fly: rather than using hacks to work around the problem in the dataloader's collate_fn, simply use an IterableDataset and write the __iter__ and __next__ functions like so:

def __iter__(self):
    return self
def __next__(self):
    # load the next non-None example
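The loading loop above can be sketched framework-free as follows (a minimal illustration of the pattern; the class name SkipNoneIterable is mine, and in PyTorch you would put the same loop inside a torch.utils.data.IterableDataset subclass):

```python
class SkipNoneIterable:
    """Iterates over raw samples, silently skipping any that fail checks (None)."""

    def __init__(self, raw_samples):
        self.raw_samples = raw_samples

    def __iter__(self):
        self._it = iter(self.raw_samples)
        return self

    def __next__(self):
        # Keep pulling raw samples until one passes the checks.
        for sample in self._it:
            if sample is not None:
                return sample
        raise StopIteration

print(list(SkipNoneIterable([1, None, 2, None, None, 3])))  # → [1, 2, 3]
```

Because invalid samples are skipped before batching ever happens, the DataLoader always fills batches to the requested size (except possibly the last one), with no custom collate_fn needed.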
answered 2020-01-03

Thanks to Brian Formento both for asking the question and for the idea of how to solve it. As mentioned, the best option of replacing corrupted examples with new ones has two issues:

  1. The newly sampled examples could also be corrupted;
  2. The dataset is not in scope.

Here is a solution to both of them: issue 1 is handled with a recursive call, and issue 2 by creating a partial function of the collate function with the dataset fixed in place.

import random
import torch


def collate_fn_replace_corrupted(batch, dataset):
    """Collate function that allows replacing corrupted examples in the
    dataloader. It expects the dataloader to return 'None' when that occurs.
    The 'None's in the batch are replaced with other, randomly sampled examples.

    Args:
        batch (torch.Tensor): batch from the DataLoader.
        dataset (torch.utils.data.Dataset): dataset which the DataLoader is loading.
            Specify it with functools.partial and pass the resulting partial function that only
            requires 'batch' argument to DataLoader's 'collate_fn' option.

    Returns:
        torch.Tensor: batch with new examples instead of corrupted ones.
    """ 
    # Idea from https://stackoverflow.com/a/57882783

    original_batch_len = len(batch)
    # Filter out all the Nones (corrupted examples)
    batch = list(filter(lambda x: x is not None, batch))
    filtered_batch_len = len(batch)
    # Num of corrupted examples
    diff = original_batch_len - filtered_batch_len
    if diff > 0:
        # Replace the corrupted examples with randomly drawn replacements
        # (randrange avoids randint's inclusive upper bound, which could index out of range)
        batch.extend([dataset[random.randrange(len(dataset))] for _ in range(diff)])
        # Recursive call to replace the replacements if they are corrupted
        return collate_fn_replace_corrupted(batch, dataset)
    # Finally, when the whole batch is fine, return it
    return torch.utils.data.dataloader.default_collate(batch)

However, you cannot pass this to the DataLoader directly, since a collate function must take a single argument, batch. To get around that, we create a partial function with the dataset specified, and pass that partial function to the DataLoader:

import functools
from torch.utils.data import DataLoader


collate_fn = functools.partial(collate_fn_replace_corrupted, dataset=dataset)
return DataLoader(dataset,
                  batch_size=batch_size,
                  num_workers=num_workers,
                  pin_memory=pin_memory,
                  collate_fn=collate_fn)
answered 2021-10-14

The Fast option has a problem: as originally posted, it repeated the duplication step inside a loop, growing the batch beyond the intended size. Below is the fixed version.

def my_collate(batch):
    len_batch = len(batch) # original batch length
    batch = list(filter (lambda x:x is not None, batch)) # filter out all the Nones
    if len_batch > len(batch): # if there are samples missing just use existing members, doesn't work if you reject every sample in a batch
        diff = len_batch - len(batch)
        batch = batch + batch[:diff] # assume diff < len(batch)
    return torch.utils.data.dataloader.default_collate(batch)
answered 2020-11-04