1
import torchvision.datasets as dsets
import torchvision.transforms as transforms
import torch.nn.init
import torch.nn.functional as F


device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

learning_rate = 0.001
training_epochs = 15
batch_size = 100

mnist_train = dsets.MNIST(root='MNIST_data/', # Specify download path
                          train=True, # Specify True to download as training data
                          transform=transforms.ToTensor(), # Convert to tensor
                          download=True)

mnist_test = dsets.MNIST(root='MNIST_data/', # Specify download path
                         train=False, # If false is specified, download as test data
                         transform=transforms.ToTensor(), # Convert to tensor
                         download=True)

这是使用CNN加载MNIST数据分类代码数据的部分

在我参考的书中,据说只提到那部分就可以看到训练集和测试集中有多少特定的数字数据。

例如,您能说出该训练或测试集中有多少“5”数据吗?

只知道可以通过mnist_train.train_data或者mnist_train.train_labels等方式访问数据张量,不知道能想到如何知道具体数值数据的个数。帮助

4

1 回答 1

1

data您可以分别使用和targets属性访问数据集的数据和标签,以进行拆分。因此,例如,您可以在此处分别使用mnist_train.data和访问训练数据和标签mnist_train.labels

由于该数据集的targets属性是torch.Tensor,因此您可以使用 来计算每个目标的实例数torch.bincount。由于总共有 10 个类,输出将是一个长度为 10 的张量,其中第 i索引指定第 i 个类的数据点的数量。

例子:

>>> mnist_train = dsets.MNIST(root='MNIST_data/', train=True, transforms.ToTensor(), download=True)
>>> mnist_train.targets
tensor([5, 0, 4,  ..., 5, 6, 8])
>>> torch.bincount(mnist_train.targets, minlength=10)
tensor([5923, 6742, 5958, 6131, 5842, 5421, 5918, 6265, 5851, 5949])

您可以看到第 5 类在训练拆分中有 5,421 个数据点。

于 2021-04-23T23:32:26.630 回答