import java.lang.Math;
import Jama.Matrix;
import java.util.Random;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
public class Neural_Network {
private int inputnodes;
private int hiddennodes;
private int outputnodes;
private double learningrate;
private Matrix wih;
private Matrix who;
public Neural_Network(int inputnodes, int hiddennodes, int outputnodes, double learningrate) {
this.inputnodes = inputnodes;
this.hiddennodes = hiddennodes;
this.outputnodes = outputnodes;
this.learningrate = learningrate;
wih = random(hiddennodes, inputnodes);
who = random(outputnodes, hiddennodes);
}
给矩阵一个随机值
private Matrix random(int i, int j) {
Matrix matrix = new Matrix(i, j);
double array[][] = matrix.getArray();
Random rand = new Random();
for(int x=0; x<i; x++) {
for(int y=0; y<j; y++) {
array[x][y] = (rand.nextDouble()+0.1);
}
}
return matrix;
}
sigmoid函数
private Matrix activate_function(Matrix mat) {
double arr[][] = mat.getArray();
for(int i=0; i<mat.getRowDimension(); i++) {
for(int j=0; j<mat.getColumnDimension(); j++) {
arr[i][j] = (1/1+Math.exp(-arr[i][j]));
}
}
return mat;
}
网络训练(反向传播)
private void training(Matrix input_list, Matrix target_list) {
Matrix inputs = input_list.transpose();
Matrix targets = target_list.transpose();
Matrix hidden_inputs = wih.times(inputs);
Matrix hidden_outputs = activate_function(hidden_inputs);
Matrix final_inputs = who.times(hidden_outputs);
Matrix final_outputs = activate_function(final_inputs);
Matrix output_errors = targets.minus(final_outputs);
Matrix hidden_errors = who.transpose().times(output_errors);
Matrix mat_1 = new Matrix(final_outputs.getRowDimension(), final_outputs.getColumnDimension(), 1);
Matrix mat_2 = new Matrix(hidden_outputs.getRowDimension(), hidden_outputs.getColumnDimension(), 1);
who.plus(output_errors.arrayTimes(final_outputs.arrayTimes(mat_1.minus(final_outputs))).times(hidden_outputs.transpose()).times(learningrate));
wih.plus(hidden_errors.arrayTimes(hidden_outputs.arrayTimes(mat_2.minus(hidden_outputs))).times(inputs.transpose()).times(learningrate));
}
private Matrix query(Matrix input_list) {
Matrix inputs = input_list.transpose();
Matrix hidden_inputs = wih.times(inputs);
Matrix hidden_outputs = activate_function(hidden_inputs);
Matrix final_inputs = who.times(hidden_outputs);
Matrix final_outputs = activate_function(final_inputs);
return final_outputs;
}
public static void main(String[] args) {
// TODO Auto-generated method stub
int epochs = 5;
int i=0;
double[] data = new double[784];
Neural_Network ann = new Neural_Network(784, 200, 10, 0.2);
for(int x=0; x<epochs; x++) {
try {
File csv = new File("mnist_train_100.csv");
BufferedReader br = new BufferedReader(new FileReader(csv));
String line = "";
while((line = br.readLine()) != null) {
String[] token = line.split(",", -1);
double target = Double.parseDouble(token[0]);
double target_array[] = {0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01};
target_array[(int) target] = 0.99;
for(i=1; i<785; i++) {
data[i-1] = Double.parseDouble(token[i]);
}
i=1;
Matrix input_list = new Matrix(data, 1);
Matrix target_list = new Matrix(target_array, 1);
ann.training(input_list, target_list);
}
br.close();
} catch (FileNotFoundException e) {
e.printStackTrace();
} catch (IOException e) {
e.printStackTrace();
}
}
try {
File csv = new File("mnist_test_10.csv");
BufferedReader br = new BufferedReader(new FileReader(csv));
String line = "";
while((line = br.readLine()) != null) {
String[] token = line.split(",", -1);
double target = Double.parseDouble(token[0]);
for(i=1; i<785; i++) {
data[i-1] = Double.parseDouble(token[i]);
}
i=1;
Matrix input_list = new Matrix(data, 1);
Matrix outputs = ann.query(input_list);
double value = outputs.norm2();
System.out.println(value);
}
br.close();
} catch (FileNotFoundException e) {
e.printStackTrace();
} catch (IOException e) {
e.printStackTrace();
}
}
}
这是我的代码，但我不知道它为什么不起作用。我正在把 https://github.com/makeyourownneuralnetwork 上的神经网络 Python 代码转换成 Java 代码。
我使用 JAMA 库编写了这段代码，但最终输出的值全都一样。请帮帮我。