如果我理解正确，您是希望把 MNIST 的整个训练集（共 60000 张图像，每张图像是一个 1x28x28 的数组，其中 1 为颜色通道数）转换为一个形状为 (60000, 1, 28, 28) 的 numpy 数组？
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
# Preprocessing pipeline: convert each image to a tensor, then normalize it
# with mean 0.1307 and standard deviation 0.3081.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ]
)

# MNIST training split, stored under ./MNIST/ (download requested if needed).
train_dataset = datasets.MNIST(
    './MNIST/',
    train=True,
    transform=transform,
    download=True,
)

# Use a batch size equal to the dataset length so that the very first batch
# produced by the loader already contains every training sample.
train_loader = DataLoader(train_dataset, batch_size=len(train_dataset))

# Pull that single batch; element 0 is taken and converted to a numpy array.
first_batch = next(iter(train_loader))
train_dataset_array = first_batch[0].numpy()

# The test split can be converted the same way:
# test_dataset = datasets.MNIST('./MNIST/', train=False, transform=transform, download=True)
# test_loader = DataLoader(test_dataset, batch_size=len(test_dataset))
# test_dataset_array = next(iter(test_loader))[0].numpy()
这是结果:
>>> train_dataset_array
array([[[[-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
-0.42421296, -0.42421296],
[-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
-0.42421296, -0.42421296],
[-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
-0.42421296, -0.42421296],
...,
[-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
-0.42421296, -0.42421296],
[-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
-0.42421296, -0.42421296],
[-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
-0.42421296, -0.42421296]]],
[[[-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
-0.42421296, -0.42421296],
[-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
-0.42421296, -0.42421296],
[-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
-0.42421296, -0.42421296],
...,
[-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
-0.42421296, -0.42421296],
[-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
-0.42421296, -0.42421296],
[-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
-0.42421296, -0.42421296]]],
[[[-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
-0.42421296, -0.42421296],
[-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
-0.42421296, -0.42421296],
[-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
-0.42421296, -0.42421296],
...,
[-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
-0.42421296, -0.42421296],
[-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
-0.42421296, -0.42421296],
[-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
-0.42421296, -0.42421296]]],
...,
[[[-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
-0.42421296, -0.42421296],
[-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
-0.42421296, -0.42421296],
[-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
-0.42421296, -0.42421296],
...,
[-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
-0.42421296, -0.42421296],
[-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
-0.42421296, -0.42421296],
[-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
-0.42421296, -0.42421296]]],
[[[-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
-0.42421296, -0.42421296],
[-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
-0.42421296, -0.42421296],
[-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
-0.42421296, -0.42421296],
...,
[-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
-0.42421296, -0.42421296],
[-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
-0.42421296, -0.42421296],
[-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
-0.42421296, -0.42421296]]],
[[[-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
-0.42421296, -0.42421296],
[-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
-0.42421296, -0.42421296],
[-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
-0.42421296, -0.42421296],
...,
[-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
-0.42421296, -0.42421296],
[-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
-0.42421296, -0.42421296],
[-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
-0.42421296, -0.42421296]]]], dtype=float32)