2

我一直在尝试训练以下网络并获得合适的权重,但它一直在运行。谁能告诉我代码中可能有什么问题?这里 {8, 1} 是输入,{-1}} 是使用符号函数的预期输出。

import java.util.Arrays;

public class ANN {

    public static void main(String args[]) {

        double threshold = 1.2;
        double learningRate = 0.08;

        // Init weights

        double[] weights = { -1.4, 1.8 };

        int[][][] trainingData = {
            {{8, 1}, {-1}},
            {{3, 2}, {-1}},
            {{6, 3}, {-1}},
            {{1, 4}, {-1}},
            {{9, 5}, {1}},
            {{5, 6}, {1}},
            {{2, 7}, {1}},
            {{4, 8}, {1}},
            {{7, 9}, {1}},
        };

        // Start training loop
        while (true) {
            int errorCount = 0;
            // Loop over training data
            for (int i = 0; i < trainingData.length; i++) {
                System.out.println("Starting weights: " + Arrays.toString(weights));
                // Calculate weighted input
                double weightedSum = 0;
                for (int ii = 0; ii < trainingData[i][0].length; ii++) {
                    weightedSum += trainingData[i][0][ii] * weights[ii];
                }

                // Calculate output
                int output = 0;
                if (threshold <= weightedSum) {
                    output = 1;
                }

                System.out.println("Target output: " + trainingData[i][1][0]
                        + ", " + "Actual Output: " + output);

                // Calculate error
                int error = trainingData[i][1][0] - output;
                System.out.println("Error:  " + error);
                // Increase error count for incorrect output
                if (error != 0) {
                    errorCount++;
                }

                // Update weights
                for (int ii = 0; ii < trainingData[i][0].length; ii++) {
                    weights[ii] += learningRate * error
                            * trainingData[i][0][ii];
                }

                System.out.println("New weights: " + Arrays.toString(weights));
                System.out.println();
            }

            // If there are no errors, stop
            if (errorCount == 0) {
                System.out
                        .println("Final weights: " + Arrays.toString(weights));
                System.exit(0);
            }
        }
    }

}

编辑:我相信问题出在计算输出的代码片段上。应该翻转它,以便如果总和大于阈值输出为 1,否则为 0。

    // Calculate output
                int output = 0;
                if (weightedSum > threshold) {
                    output = 1;
                }
4

2 回答 2

1

我已经运行了您的代码并在 (errorCount==0) 检查之前添加了一行:

System.out.println(errorCount);

这似乎在 6 和 7 之间波动,这意味着无论完成多少训练,神经网络总是会生成对训练数据的无效估计。如果训练没有达到 100% 正确的训练数据,那么这预计会永远持续下去。

希望这可以帮助!

于 2014-09-11T05:15:47.380 回答
1

你的错误可以是正面的也可以是负面的。在第一次运行中,错误为 -1。因此,errorCount 增加,退出循环的代码永远不会执行。

完全训练的条件应该基于错误本身,而不是错误计数。当错误达到最低水平(您将根据输入设置)时,培训将被视为完成。

于 2014-09-11T05:22:01.180 回答