1

我正在使用 OpenCV4Android,我正在尝试做一个神经网络的小例子来计算算术平均值。所以,我决定使用CvANN_MLP来创建网络。一切顺利,但是当我训练它时,它失败了,下一个例外:

OpenCV 错误:CvANN_MLP::prepare_to_train 中的参数错误(输出训练数据应为浮点矩阵,其行数等于训练样本数,列数等于最后(输出)层的大小)

我检查了输出训练,它的类型是 CV_32FC1。行数和列数也是正确的。当然我不知道错误在哪里。

这是我的代码,我希望有人可以帮助我。谢谢!

        int train_sample_count = 10;

        float td[][] = new float[10][3];

                    //I've created this method to populate td
        populateTrainingData(td);

        Mat trainData = new Mat(train_sample_count, 2, CvType.CV_32FC1);
        Mat trainClasses = new Mat(train_sample_count, 1, CvType.CV_32FC1);
        Mat sampleWts = new Mat(train_sample_count, 1, CvType.CV_32FC1);
        Mat neuralLayers = new Mat(3, 1, CvType.CV_32SC1);

        // input layer has 2 cells
        neuralLayers.put(0, 0, 2);
        // hidden layer has 2 cells
        neuralLayers.put(1, 0, 2);
        // output layer has 2 cells
        neuralLayers.put(2, 0, 2);

        // assembles the trainData,trainClasses and weights

        for (int i = 0; i < train_sample_count; i++) {
            trainData.put(i, 0, td[i][0]);
            trainData.put(i, 1, td[i][1]);
            trainClasses.put(i, 0, td[i][2]);
            sampleWts.put(i, 0, 1);
        }

        Log.d(DEBUG_TAG, "Assemblage is finished");

        // creates neural network with the layers of neuralLayers
        CvANN_MLP machineBrain = new CvANN_MLP(neuralLayers);

        Log.d(DEBUG_TAG, "Neural network is created");

        // trains neural network with my data
        // parameters for neural network
        CvANN_MLP_TrainParams trainParams = new CvANN_MLP_TrainParams();
        // backward propagation
        trainParams.set_train_method(CvANN_MLP_TrainParams.BACKPROP);
        // number of iterations and sigmoidal update
        TermCriteria termC = new TermCriteria(TermCriteria.EPS
                + TermCriteria.COUNT, 10000, 1.0);
        trainParams.set_term_crit(termC);

        // optional value which is zero
        Mat simpleIndex = new Mat();
        // setting up the neural network
        Log.d(DEBUG_TAG, "Setting up is finished");
        Log.d(DEBUG_TAG, "Type of trainClasses: "
                + (trainClasses.type() == CvType.CV_32FC1));
        machineBrain.train(trainData, trainClasses, sampleWts);
4

0 回答 0