-1

我正在尝试用 FANN 逼近平方函数。代码如下:

#include "../FANN-2.2.0-Source/src/include/doublefann.h"
#include "../FANN-2.2.0-Source/src/include/fann_cpp.h"
#include <cstdlib>
#include <iostream>

using namespace std;
using namespace FANN;

//Remember: fann_type is double!
int main(int argc, char** argv) {
    //create a test network: [1,2,1] MLP
    neural_net * net = new neural_net;
    const unsigned int layers[3] = {1,3,1};
    net->create_standard_array(3,layers);

    //net->create_standard(num_layers, num_input, num_hidden, num_output);

    net->set_learning_rate(0.7f);

    net->set_activation_steepness_hidden(0.7);
    net->set_activation_steepness_output(0.7);

    net->set_activation_function_hidden(SIGMOID_SYMMETRIC_STEPWISE);
    net->set_activation_function_output(SIGMOID_SYMMETRIC_STEPWISE);
    net->set_training_algorithm(TRAIN_QUICKPROP);

    //cout<<net->get_train_error_function()
    //exit(0);
    //test the number 2
    fann_type * testinput = new fann_type;
    *testinput = 2;
    fann_type * testoutput = new fann_type;
    *testoutput = *(net->run(testinput));
    double outputasdouble = (double) *testoutput;
    cout<<"Test output: "<<outputasdouble<<endl;

    //make a training set of x->x^2
    training_data * squaredata = new training_data;
    squaredata->read_train_from_file("trainingdata.txt");

    net->train_on_data(*squaredata,1000,100,0.001);

    cout<<endl<<"Easy!";
    return 0;
}

trainingdata.txt 是这样的:

10 1 1
1 1
2 4
3 9
4 16
5 25
6 36
7 49
8 64
9 81
10 100

我觉得我用 API 做的一切都是正确的。然而,当我运行它时,我得到了巨大的错误,似乎永远不会随着训练而减少。

Test output: -0.0311087
Max epochs     1000. Desired error: 0.0010000000.
Epochs            1. Current error: 633.9928588867. Bit fail 10.
Epochs          100. Current error: 614.3250122070. Bit fail 9.
Epochs          200. Current error: 614.3250122070. Bit fail 9.
Epochs          300. Current error: 614.3250122070. Bit fail 9.
Epochs          400. Current error: 614.3250122070. Bit fail 9.
Epochs          500. Current error: 614.3250122070. Bit fail 9.
Epochs          600. Current error: 614.3250122070. Bit fail 9.
Epochs          700. Current error: 614.3250122070. Bit fail 9.
Epochs          800. Current error: 614.3250122070. Bit fail 9.
Epochs          900. Current error: 614.3250122070. Bit fail 9.
Epochs         1000. Current error: 614.3250122070. Bit fail 9.

Easy!

我做错了什么?

4

1 回答 1

3

如果您对输出层使用 sigmoid 函数,则输出将提供 (0,1) 的范围。

您可能有两个选择,(1) 将所有输出除以一个常数,例如 1e4。当一个测试数据到来时,你也将它除以 1e4。问题是您可能无法预测大于 100 的平方数 (100^2=1e4);(2) 使隐藏层和输出层都为线性,并且网络将自动学习权重以给出您拥有的任何输出值.

于 2013-11-28T23:51:07.540 回答