1

我对 PyTorch 相当陌生,并且一直在试验 DataLoader 类。当我尝试加载 MNIST 数据集时,DataLoader 似乎在批处理维度之后添加了一个额外的维度。我不确定是什么导致了这种情况发生。

import torch
from torchvision.datasets import MNIST
from torchvision import transforms

if __name__ == '__main__':
    mnist_train = MNIST(root='./data', train=True, download=True, transform=transforms.Compose([transforms.ToTensor()]))
    first_x = mnist_train.data[0]
    print(first_x.shape)  # expect to see [28, 28], actual [28, 28]

    train_loader = torch.utils.data.DataLoader(mnist_train, batch_size=200)
    batch_x, batch_y = next(iter(train_loader))  # get first batch
    print(batch_x.shape)  # expect to see [200, 28, 28], actual [200, 1, 28, 28]
    # Where is the extra dimension of 1 from?

任何人都可以阐明这个问题吗?

4

1 回答 1

0

我猜这是输入图像的通道数。所以基本上是

batch_x.shape = Batch-size, No of channels, Height of the image, Width of the image

于 2019-11-28T12:34:26.430 回答