1

python 模块链接器有一个介绍,它使用其神经网络从 MNIST 数据库中识别手写数字

假设一个特定的手写数字D.png被标记为3。我习惯于标签以数组的形式出现,如下所示:

label = [0, 0, 0, 1, 0, 0, 0, 0, 0, 0]

但是,chainer使用整数标签代替:

label = 3

数组标签对我来说更直观,因为输出预测也是一个数组。在不处理图像的神经网络中,我希望灵活地将标签指定为特定数组。

我在下面的链接器介绍中直接包含了代码。如果您通过数据集进行解析traintest请注意所有标签都是整数而不是浮点数。

我将如何使用数组作为标签而不是整数来运行训练/测试数据?

import numpy as np
import chainer
from chainer import cuda, Function, gradient_check, report, training, utils, Variable
from chainer import datasets, iterators, optimizers, serializers
from chainer import Link, Chain, ChainList
import chainer.functions as F
import chainer.links as L
from chainer.training import extensions

class MLP(Chain):
    def __init__(self, n_units, n_out):
        super(MLP, self).__init__()
        with self.init_scope():
            # the size of the inputs to each layer will be inferred
            self.l1 = L.Linear(None, n_units)  # n_in -> n_units
            self.l2 = L.Linear(None, n_units)  # n_units -> n_units
            self.l3 = L.Linear(None, n_out)    # n_units -> n_out

    def __call__(self, x):
        h1 = F.relu(self.l1(x))
        h2 = F.relu(self.l2(h1))
        y = self.l3(h2)
        return y

train, test = datasets.get_mnist()

train_iter = iterators.SerialIterator(train, batch_size=100, shuffle=True)
test_iter = iterators.SerialIterator(test, batch_size=100, repeat=False, shuffle=False)

model = L.Classifier(MLP(100, 10))  # the input size, 784, is inferred
optimizer = optimizers.SGD()
optimizer.setup(model)

updater = training.StandardUpdater(train_iter, optimizer)
trainer = training.Trainer(updater, (20, 'epoch'), out='result')

trainer.extend(extensions.Evaluator(test_iter, model))
trainer.extend(extensions.LogReport())
trainer.extend(extensions.PrintReport(['epoch', 'main/accuracy', 'validation/main/accuracy']))
trainer.extend(extensions.ProgressBar())
trainer.run()
4

1 回答 1

1

分类器接受包含图像或其他数据的元组作为数组(float32)和标签作为 int。这是 chainer 的约定以及它在那里的工作方式。如果你打印你的标签,你会看到你得到了一个 dtype int 的数组。图像/非图像数据和标签都将在数组中,但 dtype 分别为 float 和 int。

所以回答你的问题:你的标签本身是数组格式,带有 dtype int(因为它应该是标签)。

如果您希望标签是 0 和 1 而不是 1 到 10,请使用 One Hot Encoding ( https://blog.cambridgespark.com/robust-one-hot-encoding-in-python-3e29bfcec77e )。

于 2019-01-29T06:05:20.380 回答