1

我正在 Pytorch 中测试 MNIST 数据集,在我对 X 数据应用转换后,DataLoader 似乎将所有值置于原始顺序之外,可能会打乱训练步骤。

我的转换是将所有值除以 255。应该注意到转换本身不会改变位置,如第一个散点图所示。但是在将数据传递给 DataLoader 并将其检索回来之后,它们就出现了故障。如果我不进行任何转换,一切都很好(未显示)。beforeafter1(除以 255/DataLoader 之前)和after2(除以 255/DataLoader 之后)(也未显示)之间的值分布相同,只是顺序似乎受到影响。

import torch
from torchvision import datasets
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

transform = transforms.ToTensor()

train = datasets.MNIST(root = '.', train = True, download = True, transform = transform)
test = datasets.MNIST(root = '.', train = False, download = True, transform = transform)

before = train.data[0]

train.data = train.data.float()/255
after1 = train.data[0]

train_loader = torch.utils.data.DataLoader(train, batch_size = 128)
test_loader = torch.utils.data.DataLoader(test, batch_size = 128)

fig, ax = plt.subplots(1, 2)
ax[0].scatter(range(len(before.view(-1))), before.view(-1))
ax[0].set_title('Before')
ax[1].scatter(range(len(after1.view(-1))), after1.view(-1))
ax[1].set_title('After1')

after2 = next(iter(train_loader))[0][0]

fig, ax = plt.subplots(1, 2)
ax[0].scatter(range(len(before.view(-1))), before.view(-1))
ax[0].set_title('Before')
ax[1].scatter(range(len(after2.view(-1))), after2.view(-1))
ax[1].set_title('After2')

fig, ax = plt.subplots(1, 3)
ax[0].imshow(before, cmap = 'gray')
ax[1].imshow(after1, cmap = 'gray')
ax[2].imshow(after2.view(28, 28), cmap = 'gray')

我知道这可能不是处理这些数据的最佳方法(transforms.Normalize应该解决它),但我真的很想了解正在发生的事情。

谢谢!

4

2 回答 2

1

所以......我在 Pytorch 的 GitHub 页面上发布了同样的问题,他们回答了以下问题:

它与数据加载器无关。您正在弄乱特定数据集对象的属性,但是,该__getitem__对象的实际功能更多: https ://github.com/pytorch/vision/blob/6de158c473b83cf43344a0651d7c01128c7850e6/torchvision/datasets/mnist.py#L92

特别是这一行 ( mode='L') 假定uint8输入。由于您将其替换为浮点数,因此是错误的。

然后我想首选的方法是在我的代码一开始就准备数据集时应用转换。

于 2019-09-14T20:52:55.667 回答
0

原来我没有测试你写的代码。改写原文:

import torch
from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset, TensorDataset
import matplotlib.pyplot as plt

transform = transforms.ToTensor()

train = datasets.MNIST(root = '.', train = True, download = True, transform = transform)
test = datasets.MNIST(root = '.', train = False, download = True, transform = transform)

dl = DataLoader(train)

images = dl.dataset.data.float()/255
labels = dl.dataset.targets

train_ds = TensorDataset(images, labels)
train_loader = DataLoader(train_ds, batch_size=128)
# img, target = next(iter(train_loader))

before = train.data[0]
train.data = train.data.float()/255
after1 = train.data[0]

# train_loader = torch.utils.data.DataLoader(train, batch_size = 128)
test_loader = torch.utils.data.DataLoader(test, batch_size = 128)

fig, ax = plt.subplots(1, 2)
ax[0].scatter(range(len(before.view(-1))), before.view(-1))
ax[0].set_title('Before')
ax[1].scatter(range(len(after1.view(-1))), after1.view(-1))
ax[1].set_title('After1')

after2 = next(iter(train_loader))[0][0]

fig, ax = plt.subplots(1, 2)
ax[0].scatter(range(len(before.view(-1))), before.view(-1))
ax[0].set_title('Before')
ax[1].scatter(range(len(after2.view(-1))), after2.view(-1))
ax[1].set_title('After2')

fig, ax = plt.subplots(1, 3)
ax[0].imshow(before, cmap = 'gray')
ax[1].imshow(after1, cmap = 'gray')
ax[2].imshow(after2.view(28, 28), cmap = 'gray')
于 2019-09-13T15:16:05.977 回答