0

我有一个非常基本的 MLP 网络:

def create_gluon_model(num_features, num_classes):
    # Create Model in Gluon
    net = nn.HybridSequential()
    net.add(nn.Dense(num_features, activation="relu"))
    net.add(nn.Dense(1000, activation="relu"))
    net.add(nn.Dense(num_classes))
    #net.hybridize()
    net.initialize(init=init.Xavier(), ctx=mx.cpu())
    return net

我的输入数据形状:(32, 20)
输出形状:(32, 4)
标签形状:(32, 4) num_classes = 4

当我尝试训练时:

def train_vmhnet(net, train_data_loader, valid_data_loader, batch_size=32):
    criterion = gluon.loss.SoftmaxCrossEntropyLoss()
    trainer = gluon.Trainer(net.collect_params(), "sgd", {"learning_rate": 0.1})
    # Start the training.
    for epoch in range(1):
        train_loss, train_acc, valid_acc = 0.0, 0.0, 0.0
        tic = time.time()
        for batch_idx, (data, label) in enumerate(train_data_loader):
            data = data.as_in_context(mx.cpu(0))
            label = label.as_in_context(ctx)
            # forward + backward
            with autograd.record():
                output = net(data)
                loss = criterion(output, label)
            loss.backward()
            # update parameters
            trainer.step(data.shape[0])
            # calculate training metrics
            train_loss += loss.mean().asscalar()
            train_acc += acc(output, label)
        print(epoch)
        # calculate validation accuracy
        for batch_idx, (data, label) in enumerate(valid_data_loader):
            data = data.as_in_context(mx.cpu(0))
            valid_acc += acc(net(data), label)
        print(
            "Epoch %d: loss %.3f, train acc %.3f, test acc %.3f, in %.1f sec"
            % (
                epoch,
                train_loss / len(dataset_train),
                train_acc / len(dataset_train),
                valid_acc / len(dataset_test),
                time.time() - tic,
            )
        )

即时消息收到以下错误:

mxnet.base.MXNetError: Shape inconsistent, Provided = [32,4], inferred shape=[32,1]

请帮忙

4

1 回答 1

0

经过几个小时的搜索,我终于在这里找到了修复程序。如果您使用一种热编码,请确保在

criterion = gluon.loss.SoftmaxCrossEntropyLoss(sparse_label=False)

这修复了它

于 2020-07-06T13:20:23.437 回答