1

我最近使用 youtube 上的一系列视频编写了一个神经网络,该频道是编码火车。它是用 js 写的,我是用 java 写的。它有时工作正常,但有时我得到 NaN 作为输出,我可以弄清楚为什么?

任何人都可以帮忙吗?一些矩阵数学和神经网络类有一个矩阵类,它本身带有一个测试问题。如果 0 大于 1,则第一个输出为 1,否则,第二个输出为 1。

编辑:我发现问题出在哪里,但我仍然无法弄清楚为什么会这样?!in 发生在我在 Matrix 类中的静态点积方法中。有时一个或两个矩阵数据都是 NaN!

编辑 2:我检查过,输入在构造函数中是有效的,但在前馈方法中它们有时是 NaN !!!可能是因为我使用的是一台 10 年前的笔记本电脑吗?因为代码似乎没有任何问题。

已解决:我找到了问题!在前馈中,我没有为输出矩阵映射 sigmoid -_-

public class NeuralNetwork {

//private int inputNodes, hiddenNodes, outputNodes;
private Matrix weightsIH, weightsHO, biasH, biasO;
private double learningRate = 0.1;

public NeuralNetwork(int inputNodes, int hiddenNodes, int outputNodes) {
    //this.inputNodes = inputNodes;
    //this.hiddenNodes = hiddenNodes;
    //this.outputNodes = outputNodes;

    weightsIH = new Matrix(hiddenNodes, inputNodes);
    weightsHO = new Matrix(outputNodes, hiddenNodes);
    weightsIH.randomize();
    weightsHO.randomize();

    biasH = new Matrix(hiddenNodes, 1);
    biasO = new Matrix(outputNodes, 1);

    biasH.randomize();
    biasO.randomize();
}

public void setLearningRate(double learningRate) {
    this.learningRate = learningRate;
}

public double sigmoid(double x) {
    return 1 / (1 + Math.exp(-x));
}

public double dsigmoid(double y) {
    return y * (1 - y);
}

public double[] feedForward(double[] inputArray) throws Exception {

    Matrix inputs = Matrix.fromArray(inputArray);
    Matrix hidden = Matrix.dot(weightsIH, inputs);
    hidden.add(biasH);

    hidden.map(f -> sigmoid(f));

    Matrix output = Matrix.dot(weightsHO, hidden);
    output.add(biasO);

    return output.toArray();
}

public void train(double[] inputArray, double[] targetsArray) throws Exception {

    Matrix targets = Matrix.fromArray(targetsArray);

    // feed forward algorithm //
    Matrix inputs = Matrix.fromArray(inputArray);
    Matrix hidden = Matrix.dot(weightsIH, inputs);
    hidden.add(biasH);

    hidden.map(f -> sigmoid(f));

    Matrix outputs = Matrix.dot(weightsHO, hidden);
    outputs.add(biasO);
    // feed forward algorithm //

    // Calculate outputs ERRORS
    Matrix outputErrors = Matrix.subtract(targets, outputs);

    // Calculate outputs Gradients
    Matrix outputsGradients = Matrix.map(outputs, f -> dsigmoid(f));
    outputsGradients.multiply(outputErrors);
    outputsGradients.multiply(learningRate);

    // Calculate outputs Deltas
    Matrix hidden_t = Matrix.transpose(hidden);
    Matrix weightsHO_deltas = Matrix.dot(outputsGradients, hidden_t);

    // adjust outputs weights
    weightsHO.add(weightsHO_deltas);
    // adjust outputs bias
    biasO.add(outputsGradients);

    // Calculate hidden layer ERRORS
    Matrix weightsHO_t = Matrix.transpose(weightsHO);
    Matrix hiddenErrors = Matrix.dot(weightsHO_t, outputErrors);

    // Calculate hidden Gradients
    Matrix hiddenGradients = Matrix.map(hidden, f -> dsigmoid(f));
    hiddenGradients.multiply(hiddenErrors);
    hiddenGradients.multiply(learningRate);

    // Calculate hidden Deltas
    Matrix inputs_t = Matrix.transpose(inputs);
    Matrix weightsIH_deltas = Matrix.dot(hiddenGradients, inputs_t);

    // adjust hidden weights
    weightsIH.add(weightsIH_deltas);
    // adjust hidden bias
    biasH.add(hiddenGradients);

}

public static void print(double[] data) {
    for (double d : data) {
        System.out.print(d + " ");
    }
    System.out.println();
}

public static void main(String[] args) {
    NeuralNetwork nn = new NeuralNetwork(3, 4, 2);
    double[][] trainingInputs = {{0, 0, 0}, {0, 0, 1}, {0, 1, 0}, {0, 1, 1}, {1, 0, 0}, {1, 0, 1}, {1, 1, 0}, {1, 1, 1}};
    double[][] targets = {{1, 0}, {1, 0}, {1, 0}, {0, 1}, {1, 0}, {0, 1}, {0, 1}, {1, 0}};

    for (int i = 0; i < 10000; i++) {
        for (int j = 0; j < trainingInputs.length; j++) {
            try {
                nn.train(trainingInputs[j], targets[j]);
            } catch (Exception e) {
                e.printStackTrace();
            }
        }
    }

    double[] output;

    try {
        output = nn.feedForward(new double[]{0, 0, 0});
        print(output);
        output = nn.feedForward(new double[]{0, 0, 1});
        print(output);
        output = nn.feedForward(new double[]{0, 1, 0});
        print(output);
        output = nn.feedForward(new double[]{0, 1, 1});
        print(output);
        output = nn.feedForward(new double[]{1, 0, 0});
        print(output);
        output = nn.feedForward(new double[]{1, 0, 1});
        print(output);
        output = nn.feedForward(new double[]{1, 1, 0});
        print(output);
        output = nn.feedForward(new double[]{1, 1, 1});
        print(output);
    } catch (Exception e) {
        e.printStackTrace();
    }
} }


public class Matrix {

public double[][] data;

public Matrix(int row, int col) {
    data = new double[row][col];
}

public Matrix(double[][] data) {

    this.data = data;
}

public void randomize() {
    for (int i = 0; i < data.length; i++) {
        for (int j = 0; j < data[0].length; j++) {
            data[i][j] = new Random().nextDouble() * 2 - 1;
        }
    }
}

public Matrix transpose() {
    Matrix result = new Matrix(data[0].length, data.length);

    for (int i = 0; i < data.length; i++) {
        for (int j = 0; j < data[0].length; j++) {
            result.data[j][i] = data[i][j];
        }
    }
    return result;
}

public static Matrix transpose(Matrix m) {
    Matrix result = new Matrix(m.data[0].length, m.data.length);

    for (int i = 0; i < m.data.length; i++) {
        for (int j = 0; j < m.data[0].length; j++) {
            result.data[j][i] = m.data[i][j];
        }
    }
    return result;
}

public void add(double n) {
    for (int i = 0; i < data.length; i++) {
        for (int j = 0; j < data[0].length; j++) {
            data[i][j] += n;
        }
    }
}

public void subtract(double n) {
    for (int i = 0; i < data.length; i++) {
        for (int j = 0; j < data[0].length; j++) {
            data[i][j] -= n;
        }
    }
}

public void add(Matrix m) throws Exception {
    if (!(data.length == m.data.length && data[0].length == m.data[0].length)) 
        throw new Exception("columns and rows don't match!");

    for (int i = 0; i < data.length; i++) {
        for (int j = 0; j < data[0].length; j++) {
            data[i][j] += m.data[i][j];
        }
    }
}

public void subtract(Matrix m) throws Exception {
    if (!(data.length == m.data.length && data[0].length == m.data[0].length))
        throw new Exception("columns and rows don't match!");

    for (int i = 0; i < data.length; i++) {
        for (int j = 0; j < data[0].length; j++) {
            data[i][j] -= m.data[i][j];
        }
    }
}

public static Matrix add(Matrix m1, Matrix m2) throws Exception {
    if (!(m1.data.length == m2.data.length && m1.data[0].length == m2.data[0].length)) 
        throw new Exception("columns and rows don't match!");

    Matrix result = new Matrix(m1.data.length, m1.data[0].length);

    for (int i = 0; i < result.data.length; i++) {
        for (int j = 0; j < result.data[0].length; j++) {
            result.data[i][j] = m1.data[i][j] + m2.data[i][j];
        }
    }

    return result;
}

public static Matrix subtract(Matrix m1, Matrix m2) throws Exception {
    if (!(m1.data.length == m2.data.length && m1.data[0].length == m2.data[0].length)) 
        throw new Exception("columns and rows don't match!");

    Matrix result = new Matrix(m1.data.length, m1.data[0].length);

    for (int i = 0; i < result.data.length; i++) {
        for (int j = 0; j < result.data[0].length; j++) {
            result.data[i][j] = m1.data[i][j] - m2.data[i][j];
        }
    }

    return result;
}

public void multiply(double n) {
    for (int i = 0; i < data.length; i++) {
        for (int j = 0; j < data[0].length; j++) {
            data[i][j] *= n;
        }
    }
}

public void multiply(Matrix m) throws Exception {
    if (!(data.length == m.data.length && data[0].length == m.data[0].length)) 
        throw new Exception("columns and rows don't match!");

    for (int i = 0; i < data.length; i++) {
        for (int j = 0; j < data[0].length; j++) {
            data[i][j] *= m.data[i][j];
        }
    }
}

public static Matrix multiply(Matrix m1, Matrix m2) throws Exception {
    if (!(m1.data.length == m2.data.length && m1.data[0].length == m2.data[0].length)) 
        throw new Exception("columns and rows don't match!");

    Matrix result = new Matrix(m1.data.length, m1.data[0].length);
    for (int i = 0; i < m1.data.length; i++) {
        for (int j = 0; j < m1.data[0].length; j++) {
            result.data[i][j] = m1.data[i][j] * m2.data[i][j];
        }
    }

    return result;
}

public Matrix dot(Matrix m) throws Exception {
    if (data[0].length != m.data.length) 
        throw new Exception("columns and rows don't match!");

    Matrix result = new Matrix(data.length, m.data[0].length);

    for (int i = 0; i < result.data.length; i++) {
        for (int j = 0; j < result.data[0].length; j++) {
            double sum = 0;
            for (int k = 0; k < data[0].length; k++) {
                sum += data[i][k] * m.data[k][j];
            }
            result.data[i][j] = sum;
        }
    }

    return result;
}

public static Matrix dot(Matrix m1, Matrix m2) throws Exception {
    if (m1.data[0].length != m2.data.length) 
        throw new Exception("columns and rows don't match!");

    Matrix result = new Matrix(m1.data.length, m2.data[0].length);

    for (int i = 0; i < result.data.length; i++) {
        for (int j = 0; j < result.data[0].length; j++) {
            double sum = 0;
            for (int k = 0; k < m1.data[0].length; k++) {
                sum += m1.data[i][k] * m2.data[k][j];
            }
            result.data[i][j] = sum;
        }
    }

    return result;
}

public static interface Func {

    public double method(double d);
}

public void map(Func f) {
    for (int i = 0 ; i < data.length; i++) {
        for (int j = 0; j < data[0].length; j++) {
            data[i][j] = f.method(data[i][j]);
        }
    }
}

public static Matrix map(Matrix m, Func f) {
    Matrix result = new Matrix(m.data.length, m.data[0].length);
    for (int i = 0 ; i < m.data.length; i++) {
        for (int j = 0; j < m.data[0].length; j++) {
            result.data[i][j] = f.method(m.data[i][j]);
        }
    }

    return result;
}

public static Matrix fromArray(double[] arr) {

    Matrix res = new Matrix(arr.length, 1);
    for (int i = 0; i < arr.length; i++) {
        res.data[i][0] = arr[i];
    }
    return res;
}

public double[] toArray() {
    double[] res = new double[data.length];

    for (int i = 0; i < data.length; i++) {
        res[i] = data[i][0];
    }

    return res;
}

public void print() {
    for (int i = 0; i < data.length; i++) {
        for (int j = 0; j < data[0].length; j++) {
            System.out.print(data[i][j] + " ");
        }
        System.out.println();
    }
}}
4

1 回答 1

1

你有几个选项来调试它,它们甚至可以一起使用。

添加调试输出

为您的所有计算添加调试输出,以便您可以查看究竟是什么导致了意外值。例如,你有...

public double sigmoid(double x) {
    return 1 / (1 + Math.exp(-x));
}

但你可以通过制作它来看看它在做什么......

public double sigmoid(double x) {
    double sigmoid = 1 / (1 + Math.exp(-x));
    System.out.println("1 / (1 + Math.exp(" + (-x) + ")) = " + sigmoid);
    return sigmoid;
}

在执行可能导致意外值的计算的任何地方执行此操作。

我建议你像这样输出一些调试信息,然后在输出内容中搜索 NaN。如果您可以将输出放入文件中,然后在文字处理器中打开该文件以进行文本搜索,这将更加容易 - 如果您在命令行上运行,java MyApp > myapp_log.txt则可以在文本编辑器中打开myapp_log.txt以执行文本搜索。

或者为了使输出更易于处理,您可以使调试逻辑仅在找到 NaN 时输出,例如...

public double sigmoid(double x) {
    double sigmoid = 1 / (1 + Math.exp(-x));
    if(sigmoid == Double.NaN)
        System.out.println("1 / (1 + Math.exp(" + (-x) + ")) = " + sigmoid);
    return sigmoid;
}

只要记住对您计算的所有内容都执行此操作,包括您的dsigmoid、您的add等,无论您在哪里进行任何类型的计算。如果你在任何地方都放了足够多的内容,那么你会发现问题并看到像“1 / (1 + Math.exp(NaN)) = NaN”这样的行输出。

使用调试器

使用调试器可以做很多事情。您可以运行您的程序,但一次只执行一行,并在发生时检查每个变量和结果。根据矩阵的大小以及这些函数被调用的次数,这可能需要付出很多努力。

或者您可以在变量上设置“监视”以在某个值等于 NaN 时让程序停止,然后检查当时程序的状态 - 我不确定是否有任何 Java 调试器具有此功能不过,因为我只在 C 或汇编中进行过这种类型的调试,所以你必须弄清楚你是否可以访问这样的调试器。

于 2019-12-05T19:36:57.010 回答