我对 PyTorch 相当陌生,并且一直在试验 DataLoader 类。当我尝试加载 MNIST 数据集时,DataLoader 似乎在批处理维度之后添加了一个额外的维度。我不确定是什么导致了这种情况发生。
import torch
from torchvision.datasets import MNIST
from torchvision import transforms
if __name__ == '__main__':
mnist_train = MNIST(root='./data', train=True, download=True, transform=transforms.Compose([transforms.ToTensor()]))
first_x = mnist_train.data[0]
print(first_x.shape) # expect to see [28, 28], actual [28, 28]
train_loader = torch.utils.data.DataLoader(mnist_train, batch_size=200)
batch_x, batch_y = next(iter(train_loader)) # get first batch
print(batch_x.shape) # expect to see [200, 28, 28], actual [200, 1, 28, 28]
# Where is the extra dimension of 1 from?
任何人都可以阐明这个问题吗?