1
import java.lang.Math;
import Jama.Matrix;
import java.util.Random;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;


public class Neural_Network {

   private int inputnodes;
   private int hiddennodes;
   private int outputnodes;
   private double learningrate;
   private Matrix wih;
   private Matrix who;

   public Neural_Network(int inputnodes, int hiddennodes, int outputnodes, double learningrate) {
      this.inputnodes = inputnodes;
      this.hiddennodes = hiddennodes;
      this.outputnodes = outputnodes;
      this.learningrate = learningrate;

      wih = random(hiddennodes, inputnodes);
      who = random(outputnodes, hiddennodes);
   }

给矩阵一个随机值

   private Matrix random(int i, int j) {
      Matrix matrix = new Matrix(i, j);
      double array[][] = matrix.getArray();
      Random rand = new Random();
      for(int x=0; x<i; x++) {
         for(int y=0; y<j; y++) {
            array[x][y] = (rand.nextDouble()+0.1);
         }
      }
      return matrix;
   }

sigmoid函数

   private Matrix activate_function(Matrix mat) {
      double arr[][] = mat.getArray();

      for(int i=0; i<mat.getRowDimension(); i++) {
          for(int j=0; j<mat.getColumnDimension(); j++) {
              arr[i][j] = (1/1+Math.exp(-arr[i][j]));
          }
      }
      return mat;
    }

网络训练(反向传播)

   private void training(Matrix input_list, Matrix target_list) {
       Matrix inputs = input_list.transpose();
       Matrix targets = target_list.transpose();

       Matrix hidden_inputs = wih.times(inputs);
       Matrix hidden_outputs = activate_function(hidden_inputs);

       Matrix final_inputs = who.times(hidden_outputs);
       Matrix final_outputs = activate_function(final_inputs);

       Matrix output_errors = targets.minus(final_outputs);
       Matrix hidden_errors = who.transpose().times(output_errors);

       Matrix mat_1 = new Matrix(final_outputs.getRowDimension(), final_outputs.getColumnDimension(), 1);
       Matrix mat_2 = new Matrix(hidden_outputs.getRowDimension(), hidden_outputs.getColumnDimension(), 1);

       who.plus(output_errors.arrayTimes(final_outputs.arrayTimes(mat_1.minus(final_outputs))).times(hidden_outputs.transpose()).times(learningrate));
       wih.plus(hidden_errors.arrayTimes(hidden_outputs.arrayTimes(mat_2.minus(hidden_outputs))).times(inputs.transpose()).times(learningrate));
   }

   private Matrix query(Matrix input_list) {
      Matrix inputs = input_list.transpose();

      Matrix hidden_inputs = wih.times(inputs);
      Matrix hidden_outputs = activate_function(hidden_inputs);

      Matrix final_inputs = who.times(hidden_outputs);
      Matrix final_outputs = activate_function(final_inputs);

      return final_outputs;
   }

   public static void main(String[] args) {
      // TODO Auto-generated method stub
      int epochs = 5;
      int i=0;
      double[] data = new double[784];

      Neural_Network ann = new Neural_Network(784, 200, 10, 0.2);



      for(int x=0; x<epochs; x++) {
          try {
              File csv = new File("mnist_train_100.csv");
              BufferedReader br = new BufferedReader(new FileReader(csv));
              String line = "";
              while((line = br.readLine()) != null) {
                  String[] token = line.split(",", -1);

                  double target = Double.parseDouble(token[0]);
                  double target_array[] = {0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01};
                  target_array[(int) target] = 0.99;

                  for(i=1; i<785; i++) {
                      data[i-1] = Double.parseDouble(token[i]);
                  }
                  i=1;

                  Matrix input_list = new Matrix(data, 1);
                  Matrix target_list = new Matrix(target_array, 1);

                  ann.training(input_list, target_list);
              }
              br.close();
          } catch (FileNotFoundException e) {
              e.printStackTrace();
          } catch (IOException e) {
              e.printStackTrace();
          }
      }

      try {
          File csv = new File("mnist_test_10.csv");
          BufferedReader br = new BufferedReader(new FileReader(csv));
          String line = "";
          while((line = br.readLine()) != null) {
              String[] token = line.split(",", -1);

              double target = Double.parseDouble(token[0]);

              for(i=1; i<785; i++) {
                  data[i-1] = Double.parseDouble(token[i]);
              }
              i=1;
              Matrix input_list = new Matrix(data, 1);
              Matrix outputs = ann.query(input_list);
              double value = outputs.norm2();
              System.out.println(value);
          }
          br.close();
      } catch (FileNotFoundException e) {
          e.printStackTrace();
      } catch (IOException e) {
          e.printStackTrace();
      }
   }
}

这是我的代码,但我不知道它为什么不起作用。我正在把 https://github.com/makeyourownneuralnetwork 上的神经网络 Python 代码转换成 Java 代码。

我用 JAMA 库编写了这段代码,但最终的输出值全都一样。请帮帮我。

4

0 回答 0