1

我正在尝试在 Torch 的 XOR 函数上训练一个简单的测试网络。当我使用 MSECriterion 时它可以工作,但是当我尝试 CrossEntropyCriterion 时它会失败并显示以下错误消息:

/home/a/torch/install/bin/luajit: /home/a/torch/install/share/lua/5.1/nn/THNN.lua:699: Assertion `cur_target >= 0 && cur_target < n_classes' failed.  at /tmp/luarocks_nn-scm-1-6937/nn/lib/THNN/generic/ClassNLLCriterion.c:31
stack traceback:
    [C]: in function 'v'
    /home/a/torch/install/share/lua/5.1/nn/THNN.lua:699: in function 'ClassNLLCriterion_updateOutput'
    ...e/a/torch/install/share/lua/5.1/nn/ClassNLLCriterion.lua:41: in function 'updateOutput'
    ...torch/install/share/lua/5.1/nn/CrossEntropyCriterion.lua:13: in function 'forward'
    .../a/torch/install/share/lua/5.1/nn/StochasticGradient.lua:35: in function 'train'
    a.lua:34: in main chunk
    [C]: in function 'dofile'
    /home/a/torch/install/lib/luarocks/rocks/trepl/scm-1/bin/th:145: in main chunk
    [C]: at 0x00406670

将其分解为 LogSoftMax 和 ClassNLLCriterion 时,我收到相同的错误消息。代码是:

dataset={};
function dataset:size() return 100 end -- 100 examples
for i=1,dataset:size() do
  local input = torch.randn(2);     -- normally distributed example in 2d
  local output = torch.Tensor(2);
  if input[1]<0 then
      input[1]=-1
  else
      input[1]=1
  end
  if input[2]<0 then
      input[2]=-1
  else
      input[2]=1
  end
  if input[1]*input[2]>0 then     -- calculate label for XOR function
    output[2] = 1;
  else
    output[1] = 1
  end
  dataset[i] = {input, output}
end

require "nn"
mlp = nn.Sequential();  -- make a multi-layer perceptron
inputs = 2; outputs = 2; HUs = 20; -- parameters
mlp:add(nn.Linear(inputs, HUs))
mlp:add(nn.Tanh())
mlp:add(nn.Linear(HUs, outputs))

criterion = nn.CrossEntropyCriterion()
trainer = nn.StochasticGradient(mlp, criterion)
trainer.learningRate = 0.01
trainer:train(dataset)

x = torch.Tensor(2)
x[1] =  1; x[2] =  1; print(mlp:forward(x))
x[1] =  1; x[2] = -1; print(mlp:forward(x))
x[1] = -1; x[2] =  1; print(mlp:forward(x))
x[1] = -1; x[2] = -1; print(mlp:forward(x))
4

1 回答 1

3

MSE 准则是为回归问题设计的。当它用于分类任务时,目标应该是 one-hot 向量。交叉熵/负对数似然标准专门用于分类;因此,无需将目标类显式表示为向量。在torch此类标准的目标中只是分配类的索引(1 到类的数量)。

于 2016-02-08T15:23:48.187 回答