我是 PyTorch 的新手,我正在使用经典的 MNIST 数据集进行图像分类。在拟合模型时,我遇到了错误:
NotImplementedError: uint8
我正在使用 fastai 库的类作为所有训练和验证数据的包装器以及一个非常基本的单层神经网络。我正在使用的代码如下:
from keras.datasets import mnist
import matplotlib.pyplot as plt
from fastai.metrics import *
from fastai.model import *
from fastai.dataset import *
import torch.nn as nn
(x_train, y_train), (x_valid, y_valid) = mnist.load_data()
net = nn.Sequential(
nn.Linear(784,10),
nn.Softmax()).cuda()
md = ImageClassifierData.from_arrays('/data/mnist',
(x_train,y_train),
(x_valid, y_valid))
loss = nn.NLLLoss()
metrics = [accuracy]
opt=optim.SGD(net.parameters(), 1e-1, momentum=0.9, weight_decay=1e-3)
fit(net, md, n_epochs=3, crit=loss, opt=opt, metrics=metrics)
有人可以告诉这个错误是什么以及它的解决方案吗?