我正在尝试使用 MNIST 数据集。torchvision.datasets
它似乎是作为N x H x W (uint8)
(批量尺寸、高度、宽度)张量提供的。然而,所有用于处理图像的 pytorch 类(例如Conv2d
)都需要一个N x C x H x W (float32)
张量,其中C
是颜色通道的数量。我尝试添加添加ToTensor
变换,但没有添加颜色通道。
有没有办法torchvision.transforms
用来添加这个额外的维度?对于 rawtensor
我们可以做.unsqueeze(1)
,但这看起来不是一个非常优雅的解决方案。我只是想以“正确”的方式来做。
这是失败的转换。
import torchvision
dataset = torchvision.datasets.MNIST("~/PyTorchDatasets/MNIST/", train=True, transform=torchvision.transforms.ToTensor(), download=True)
print(dataset.train_data[0])