我正在尝试构建一个 PyTorch 数据集,它21x512x512
从 shape 的 3D 图像返回切片?x512x512
。我知道有多少张图像,但我不知道每张图像中有多少个切片。因此,我会直观地让__len__()
Dataset 的函数返回我拥有的图像总数。我可以事先在技术上检查所有图像的形状,但数据集可能会随着时间而改变,所以我更喜欢可扩展的软件解决方案。
有了这个,我需要一些功能将图像分成切片(上面提到的大小),并返回这些而不是整个图像。这也不是问题,我有一个功能可以做到这一点。
问题来了。如果我在 Dataset 的函数中添加这个切片功能__getitem__()
,那么我将只能得到每个图像一个切片,因为 PyTorch DataLoader 会认为有len(dataset)
数据点,现在情况不再如此。但我也无法指定正确的样本数量,因为我事先并不知道。
我尝试了一些解决方案:
- 返回一个生成器函数,
__getitem__()
其中每个图像产生切片。这不起作用,因为__getitem__()
需要返回类型list
,tuple
等的东西tensor
。 - 只需返回整个图像并在训练循环中将其分解。这可以工作,但既不好的编程风格(因为我想隐藏数据集中的数据选择)并且与 DataLoader 的批处理不太兼容,因为一个图像可能有 100 个切片,而另一个可能只有 5 个切片. 在这种情况下,从这些图像制作批次将导致只有 5 个批次具有实际的
batch_size
,而其他 95 个批次的样本较少。解决这个问题需要进行一些难看的检查,看看是否需要加载另一个图像,我想再次将其隐藏在数据集中。 - Yield 会导致Dataset 函数出现
for
循环。__getitem__()
这与第 1 点的原因不同:无法在数据集中返回生成器。
简而言之,什么是从 PyTorch 数据集中的 3D 图像加载未知数量切片的干净方法?