import torchvision.datasets as dsets
import torchvision.transforms as transforms
import torch.nn.init
import torch.nn.functional as F
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)
learning_rate = 0.001
training_epochs = 15
batch_size = 100
mnist_train = dsets.MNIST(root='MNIST_data/', # Specify download path
train=True, # Specify True to download as training data
transform=transforms.ToTensor(), # Convert to tensor
download=True)
mnist_test = dsets.MNIST(root='MNIST_data/', # Specify download path
train=False, # If false is specified, download as test data
transform=transforms.ToTensor(), # Convert to tensor
download=True)
这是使用CNN加载MNIST数据分类代码数据的部分
在我参考的书中,据说只提到那部分就可以看到训练集和测试集中有多少特定的数字数据。
例如,您能说出该训练或测试集中有多少“5”数据吗?
只知道可以通过mnist_train.train_data或者mnist_train.train_labels等方式访问数据张量,不知道能想到如何知道具体数值数据的个数。帮助