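I'm using PyTorch's DataLoader and built a dataset from ImageFolder. My dataset class holds two ImageFolder datasets that are paired (original images and ground-truth images), and I want to feed these pairs to a PyTorch neural network.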

Dataset class:

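from torch.utils.data import Dataset

class bsds_dataset(Dataset):
    """Pairs two index-aligned ImageFolder datasets (originals and ground truth)."""
    def __init__(self, ds_main, ds_energy):
        self.dataset1 = ds_main
        self.dataset2 = ds_energy

    def __getitem__(self, index):
        # Each ImageFolder item is an (image, class_index) tuple
        x1 = self.dataset1[index]
        x2 = self.dataset2[index]

        return x1, x2

    def __len__(self):
        return len(self.dataset1)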
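ImageFolder returns each item as an (image_tensor, class_index) tuple, so bsds_dataset yields ((image, label), (image, label)) and default_collate has to recurse through that nesting, which is what the repeated collate.py frames in the traceback show. If the folder labels are not needed, a variant that returns only the two tensors keeps batches simpler. The sketch below uses a hypothetical PairedDataset name and lists of random tensors as stand-ins for the two ImageFolder datasets; it assumes item i of one dataset is the ground truth for item i of the other, which with ImageFolder holds only if the corresponding files sort into the same order.

```python
import torch
from torch.utils.data import Dataset


class PairedDataset(Dataset):
    """Hypothetical variant of bsds_dataset that drops ImageFolder's class labels.

    Assumes ds_main and ds_energy are index-aligned and that every item is an
    (image, class_index) tuple, as ImageFolder returns.
    """

    def __init__(self, ds_main, ds_energy):
        assert len(ds_main) == len(ds_energy), "datasets must have the same length"
        self.ds_main = ds_main
        self.ds_energy = ds_energy

    def __getitem__(self, index):
        image, _ = self.ds_main[index]     # discard the folder label
        target, _ = self.ds_energy[index]  # discard the folder label
        return image, target

    def __len__(self):
        return len(self.ds_main)


# Stand-ins for the two ImageFolder datasets: lists of (tensor, label) tuples.
fake_main = [(torch.rand(3, 321, 481), 0) for _ in range(4)]
fake_energy = [(torch.rand(3, 321, 481), 0) for _ in range(4)]

ds = PairedDataset(fake_main, fake_energy)
img, tgt = ds[0]
print(img.shape, tgt.shape)
```

With real data it would be constructed the same way as bsds_dataset, e.g. PairedDataset(original_ds, energy_ds).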

I'm loading the images with ImageFolder:

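from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder

original_imagefolder = './images/whole'
target_imagefolder = './results/whole'

# ToTensor turns each image into a C x H x W float tensor
original_ds = ImageFolder(original_imagefolder, transform=transforms.ToTensor())
energy_ds = ImageFolder(target_imagefolder, transform=transforms.ToTensor())

dataset = bsds_dataset(original_ds, energy_ds)
loader = DataLoader(dataset, batch_size=16)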
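DataLoader builds each batch with its collate_fn; the default one stacks the samples with torch.stack (the last frame of the traceback), so every image in a batch must have the same shape. One option that neither crops nor stretches is a custom collate_fn that zero-pads each batch to its largest height and width. The sketch below is a hypothetical pad_collate; it assumes each sample is an (image, target) pair of C x H x W tensors whose H and W match, so it fits a dataset that drops ImageFolder's labels rather than the stock bsds_dataset, which returns nested (image, label) tuples.

```python
import torch
import torch.nn.functional as F


def pad_collate(batch):
    """Hypothetical collate_fn: zero-pads every tensor in the batch to the
    largest height and width present, then stacks. Nothing is resized.

    Assumes each sample is (image, target), both C x H x W tensors, and that a
    target has the same H and W as its image.
    """
    max_h = max(img.shape[1] for img, _ in batch)
    max_w = max(img.shape[2] for img, _ in batch)

    def pad(t):
        # F.pad takes (left, right, top, bottom) for the last two dimensions;
        # padding is added on the right and bottom only.
        return F.pad(t, (0, max_w - t.shape[2], 0, max_h - t.shape[1]))

    images = torch.stack([pad(img) for img, _ in batch])
    targets = torch.stack([pad(tgt) for _, tgt in batch])
    return images, targets


# A landscape and a portrait sample in the same batch.
batch = [
    (torch.rand(3, 321, 481), torch.rand(1, 321, 481)),
    (torch.rand(3, 481, 321), torch.rand(1, 481, 321)),
]
x, y = pad_collate(batch)
print(x.shape, y.shape)
```

It would be passed as DataLoader(dataset, batch_size=16, collate_fn=pad_collate). With mixed orientations every batch is padded to 481x481, so about a third of each padded tensor is zeros, and the loss would probably need to ignore the padded region.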

Then I try to iterate over the batches:

# enumerate yields (index, batch); each batch is an (x, y) pair
for i, (x, y) in enumerate(loader):
    print(x)

This error occurred:

RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 0. Got 321 and 481 in dimension 2 at ..\aten\src\TH/generic/THTensor.cpp:711

The dataset is BSDS500:

https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/resources.html

Every image in the dataset is either 481x321 or 321x481 pixels. I think some transform is needed, but I don't want to cut up or stretch the images.
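Because every BSDS500 image has one of just two sizes, one way to get a single shape without cropping or resampling is to rotate portrait images by 90 degrees. The hypothetical to_landscape helper below does that on the C x H x W tensor produced by ToTensor; after it, every tensor from ImageFolder (whose default loader converts images to RGB) would be 3 x 321 x 481.

```python
import torch


def to_landscape(img):
    """Hypothetical helper: rotate a C x H x W tensor by 90 degrees if it is
    portrait (H > W); landscape tensors are returned unchanged.

    Rotation only rearranges pixels, so nothing is cropped or resampled.
    """
    if img.shape[1] > img.shape[2]:
        return torch.rot90(img, k=1, dims=(1, 2))
    return img


portrait = torch.rand(3, 481, 321)
landscape = to_landscape(portrait)
print(landscape.shape)  # torch.Size([3, 321, 481])
```

It could be appended to both ImageFolder transforms, e.g. transforms.Compose([transforms.ToTensor(), transforms.Lambda(to_landscape)]). Since each ground-truth map has the same orientation as its original, both get the same rotation and stay aligned, and a prediction for a portrait image can be turned back with torch.rot90(pred, k=-1, dims=(1, 2)). Defining it as a module-level function rather than a lambda keeps the transform picklable, which matters when num_workers > 0 on Windows. Whether the network should see portrait images rotated is a modeling choice worth checking for the task at hand.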

Full traceback:

C:\Anaconda3\envs\torchgpu\lib\site-packages\ipykernel_launcher.py:77: UserWarning: nn.init.xavier_normal is now deprecated in favor of nn.init.xavier_normal_.
C:\Anaconda3\envs\torchgpu\lib\site-packages\ipykernel_launcher.py:78: UserWarning: nn.init.constant is now deprecated in favor of nn.init.constant_.
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-42-4c4ba0a13c32> in <module>
      5 optimizer = optim.SGD(model.parameters(), lr=0.001)
      6 for epoch in range(epochs):
----> 7     for i, batch in enumerate(loader):
      8         print(batch)

C:\Anaconda3\envs\torchgpu\lib\site-packages\torch\utils\data\dataloader.py in __next__(self)
    558         if self.num_workers == 0:  # same-process loading
    559             indices = next(self.sample_iter)  # may raise StopIteration
--> 560             batch = self.collate_fn([self.dataset[i] for i in indices])
    561             if self.pin_memory:
    562                 batch = _utils.pin_memory.pin_memory_batch(batch)

C:\Anaconda3\envs\torchgpu\lib\site-packages\torch\utils\data\_utils\collate.py in default_collate(batch)
     66     elif isinstance(batch[0], container_abcs.Sequence):
     67         transposed = zip(*batch)
---> 68         return [default_collate(samples) for samples in transposed]
     69 
     70     raise TypeError((error_msg_fmt.format(type(batch[0]))))

C:\Anaconda3\envs\torchgpu\lib\site-packages\torch\utils\data\_utils\collate.py in <listcomp>(.0)
     66     elif isinstance(batch[0], container_abcs.Sequence):
     67         transposed = zip(*batch)
---> 68         return [default_collate(samples) for samples in transposed]
     69 
     70     raise TypeError((error_msg_fmt.format(type(batch[0]))))

C:\Anaconda3\envs\torchgpu\lib\site-packages\torch\utils\data\_utils\collate.py in default_collate(batch)
     66     elif isinstance(batch[0], container_abcs.Sequence):
     67         transposed = zip(*batch)
---> 68         return [default_collate(samples) for samples in transposed]
     69 
     70     raise TypeError((error_msg_fmt.format(type(batch[0]))))

C:\Anaconda3\envs\torchgpu\lib\site-packages\torch\utils\data\_utils\collate.py in <listcomp>(.0)
     66     elif isinstance(batch[0], container_abcs.Sequence):
     67         transposed = zip(*batch)
---> 68         return [default_collate(samples) for samples in transposed]
     69 
     70     raise TypeError((error_msg_fmt.format(type(batch[0]))))

C:\Anaconda3\envs\torchgpu\lib\site-packages\torch\utils\data\_utils\collate.py in default_collate(batch)
     41             storage = batch[0].storage()._new_shared(numel)
     42             out = batch[0].new(storage)
---> 43         return torch.stack(batch, 0, out=out)
     44     elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
     45             and elem_type.__name__ != 'string_':

RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 0. Got 321 and 481 in dimension 2 at ..\aten\src\TH/generic/THTensor.cpp:711
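The failure can be reproduced without any files. default_collate ends by calling torch.stack on the samples in a batch (line 43 of collate.py in the traceback), and torch.stack requires every tensor to have the same shape. A ToTensor image is C x H x W, so a 481x321 landscape image becomes 3 x 321 x 481 while a 321x481 portrait image becomes 3 x 481 x 321. The random tensors below stand in for two such images; the exact error text varies between PyTorch versions.

```python
import torch

landscape = torch.rand(3, 321, 481)  # 481 wide x 321 tall, as C x H x W
portrait = torch.rand(3, 481, 321)   # 321 wide x 481 tall, as C x H x W

stacked_ok = True
try:
    # Roughly what default_collate does with a batch of image tensors
    torch.stack([landscape, portrait], 0)
except RuntimeError as err:
    stacked_ok = False
    print("torch.stack failed:", err)

# Same-shaped tensors stack fine into a B x C x H x W batch
same = torch.stack([landscape, torch.rand(3, 321, 481)], 0)
print(same.shape)  # torch.Size([2, 3, 321, 481])
```

So any batch of 16 that happens to mix orientations fails, while a batch containing only one orientation would stack.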
