2

最近,我开始尝试使用反向传播来训练神经网络。网络结构是784-512-10,我用的是Sigmoid激活函数。当我在 MNIST 数据集上测试单层网络时,我得到了大约 90%。这个多层网络我的结果大约是 86%,这正常吗?我是否弄错了反向传播部分?

这是我的代码:

import javax.imageio.ImageIO;
import java.awt.*;
import java.awt.image.BufferedImage;
import java.io.*;
import java.nio.file.Files;
import java.nio.file.StandardCopyOption;
import java.security.SecureRandom;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Random;
import java.util.Scanner;

public class NeuralNetwork{
    public static double learningRate = 0.01;
    public static int epoch = 15;
    public static int ROWS = 28;
    public static int COLUMNS = 28;
    public static int INPUT = ROWS * COLUMNS;
    public static int outNum = 10;
    public static int hiddenNum = 512;
    public static double[][] weights2 = new double[outNum][hiddenNum];
    public static double[] bias2 = new double[outNum];
    public static double[][] weights1 = new double[hiddenNum][INPUT];
    public static double[] bias1 = new double[outNum];
    private static final double TRAININGSIZE = 10;
    public static double[][] inputs = new double[outNum][INPUT];
    private static final double[][] target = new double[outNum][outNum];

    private static final ArrayList<String> filenames = new ArrayList<>();
    private static final ArrayList<Integer> yetDone = new ArrayList<>();
    public static double[] actual = new double[outNum];

    public static Random rand = new SecureRandom();

    public static Scanner input = new Scanner(System.in);
    public static void main(String[]args) throws Exception {
        System.out.println("1. Learn the network");
        System.out.println("2. Guess a number");
        System.out.println("3. Guess file");
        System.out.println("4. Guess All Numbers");
        System.out.println("5. Guess image");
        switch (input.nextInt()){
            case 1:
                learn();
                break;
            case 2:
                guess();
                break;
            case 3:
                guessFile();
                break;
            case 4:
                guessAll();
                break;
        }
    }

    public static void guessAll() throws IOException, ClassNotFoundException {
        System.out.println("Recognizing...");
        /*
        for(int x = 1; x < 60000; x++){
            filenames.add("data/" + String.format("%05d",x) + ".txt");
        }

        ObjectInputStream ois = new ObjectInputStream(new BufferedInputStream(new FileInputStream("network.ser")));
        Layers lay = (Layers) ois.readObject();
        int correct = 0;
        for (String z : filenames) {
            double[] a = scan(z,0);
            correct += getBestGuess(sigmoid(lay.step(a))) == actual[0] ? 1 : 0;
        }
        System.out.println("Training: " + correct + " / " + filenames.size() + " correct.");
        filenames.clear();

         */

        for(int x = 60000; x < 70000; x++){
            filenames.add("data/" + String.format("%05d",x) + ".txt");
        }

        ObjectInputStream oiss = new ObjectInputStream(new BufferedInputStream(new FileInputStream("network1.ser")));
        Layers lays1 = (Layers) oiss.readObject();
        ObjectInputStream oiss2 = new ObjectInputStream(new BufferedInputStream(new FileInputStream("network2.ser")));
        Layers lays2 = (Layers) oiss2.readObject();
        int corrects = 0;
        for (String z : filenames) {
            double[] a = scan(z,0);
            corrects += getBestGuess(sigmoid(lays2.step(sigmoid(lays1.step(a))))) == actual[0] ? 1 : 0;
        }
        System.out.println("Testing: " + corrects + " / " + filenames.size() + " correct.");
        
        System.out.println("Done!");
    }

    public static void makeList(){
        for(int index = 0; index < TRAININGSIZE; index++){
            int indices = rand.nextInt(yetDone.size() - 1) + 1;
            filenames.add("data/" + String.format("%05d",yetDone.get(indices)) + ".txt");
            yetDone.remove(indices);
        }
        prepareData();
        for(int indices = 0; indices < outNum; indices++) {
            for(int index = 0; index < outNum; index++){
                target[indices][index] = 0;
            }
            target[indices][(int)actual[indices]] = 1;
        }
    }

    public static void prepareData(){
        for(int index = 0; index < outNum; index++){
            try {
                inputs[index] = scan(filenames.get(index), index);
            } catch (FileNotFoundException ex) {
                ex.printStackTrace();
            }
        }
    }

    public static double[] scan(String filename, int index) throws FileNotFoundException {
        Scanner in = new Scanner(new File(filename));
        double[] a = new double[INPUT];
        for(int i = 0; i < INPUT; i++){
            a[i] = in.nextDouble() / 255;
        }
        actual[index] = in.nextDouble();
        return a;
    }

    public static void guessFile() throws IOException, ClassNotFoundException {
        System.out.print("Enter Filename: ");
        double[] a = scan(input.next(), 0);
        ObjectInputStream ois = new ObjectInputStream(new BufferedInputStream(new FileInputStream("network.ser")));
        Layers lay = (Layers) ois.readObject();
        double[] results = lay.step(a);
        System.out.println("This is a " + getBestGuess(sigmoid(results)) + "!");
        System.out.println(Arrays.toString(results));
    }

    public static double guess(double[] a) throws IOException, ClassNotFoundException {
        ObjectInputStream ois = new ObjectInputStream(new BufferedInputStream(new FileInputStream("network.ser")));
        Layers lay = (Layers) ois.readObject();
        double[] results = lay.step(a);
        return getBestGuess(sigmoid(results));
    }

    public static void guess() throws IOException, ClassNotFoundException {
        ObjectInputStream ois = new ObjectInputStream(new BufferedInputStream(new FileInputStream("network.ser")));
        System.out.println("Input number: ");
        Layers lay = (Layers) ois.readObject();
        double[] a = new double[INPUT];
        for(int index = 0; index < a.length; index++){
            a[index] = input.nextInt();
        }
        double[] results = lay.step(a);
        System.out.println("This is a " + getBestGuess(sigmoid(results)) + "!");
        System.out.println(Arrays.toString(sigmoid(results)));
    }

    public static void learn() {
        System.out.println("Learning...");
        initialise(weights2, outNum, hiddenNum);
        initialise(bias2);
        initialise(weights1,hiddenNum, INPUT);
        initialise(bias1);

        Layers lay2 = new Layers(weights2, bias2, outNum, hiddenNum);
        Layers lay1 = new Layers(weights1, bias1, hiddenNum, INPUT);

        double[] result2 = new double[lay2.outNum];
        double[] result1 = new double[lay1.outNum];
        double[] a2;
        double[] a1;
        double cost = 0;
        double sumFinal;

        for(int x = 0; x < epoch; x++) {
            yetDone.clear();
            for(int y = 0; y < 60000; y++){
                yetDone.add(y);
            }

            for (int ind = 0; ind < 200; ind++) {
                filenames.clear();
                makeList();
                for (int n = 0; n < lay2.outNum; n++) {
                    a1 = inputs[n]; //number
                    result1 = sigmoid(lay1.step(a1));
                    a2 = result1;
                    result2 = sigmoid(lay2.step(a2));

                    for (int i = 0; i < lay2.outNum; i++) {
                        for (int j = 0; j < lay2.INPUT; j++) {
                            weights2[i][j] += learningRate * a2[j] * (target[n][i] - result2[i]);
                            cost += Math.pow((target[n][i] - result2[i]), 2);
                        }
                    }

                    for(int i = 0; i < lay1.outNum; i++){
                        for(int j = 0; j < lay1.INPUT; j++){
                            sumFinal = 0;
                            for(int k = 0; k < lay2.outNum; k++){
                                // weight * derivSigma(outputHiddenLayer) * 2(out - expected)
                                sumFinal += result1[k] * (1 - result1[k]) * 2 * (result2[k] - target[n][k]); // * weights2[k][i]
                            }
                            weights1[i][j] -= learningRate * a1[j] * sumFinal * result1[i] * (1 - result1[i]);
                        }
                    }
                }
                lay1.update(weights1, bias1);
                lay2.update(weights2, bias2);
            }
            System.out.println("Epoch " + x + ": " + cost);
            cost = 0;
        }
        System.out.println(Arrays.toString(result1));
        System.out.println(Arrays.toString(result2));

        for(double[] arr : inputs) {
            System.out.println("This is a " + getBestGuess(sigmoid(lay2.step(sigmoid(lay1.step(arr))))) + "!");
        }

        try (ObjectOutputStream oos = new ObjectOutputStream(new BufferedOutputStream(new FileOutputStream("network1.ser")))) {
            oos.writeObject(lay1);
        } catch (IOException ex) {
            ex.printStackTrace();
        }

        try (ObjectOutputStream oos = new ObjectOutputStream(new BufferedOutputStream(new FileOutputStream("network2.ser")))) {
            oos.writeObject(lay2);
        } catch (IOException ex) {
            ex.printStackTrace();
        }

        System.out.println("Done! Saved to file.");
    }

    public static double sigmoid(double x){
        return 1 / (1 + Math.exp(-x));
    }

    public static double[] sigmoid(double[] weights){
        for(int index = 0; index < weights.length; index++){
            weights[index] = sigmoid(weights[index]);
        }
        return weights;
    }

    public static void initialise(double[] bias){
        Random random = new Random();
        for(int index = 0; index < bias.length; index++){
            bias[index] = random.nextGaussian();
        }
    }

    public static void initialise(double[][] weights, int outNum, int INPUT){
        Random random = new Random();
        for(int index = 0; index < outNum; index++){
            for(int indice = 0; indice < INPUT; indice++){
                weights[index][indice] = random.nextGaussian();
            }
        }
    }

    public static int getBestGuess(double[] result){
        double k = Integer.MIN_VALUE;
        double index = 0;
        int current = 0;
        for(double a : result){
            if(k < a){
                k = a;
                index = current;
            }
            current++;
        }

        return (int)index;
    }
}

class Layers implements Serializable {
    private static final long serialVersionUID = 8L;
    double[][] weights;
    double[] bias;
    int outNum;
    int INPUT;

    public Layers(double[][] weights, double[] bias, int outNum, int INPUT){
        this.weights = weights;
        this.bias = bias;
        this.outNum = outNum;
        this.INPUT = INPUT;
    }

    public void update(double[][] weights, double[] bias){
        this.weights = weights;
        this.bias = bias;
    }

    public double[] step(double[] aa){
        double[] out = new double[outNum];
        for (int index = 0; index < outNum; index++) {
            for (int indices = 0; indices < INPUT; indices++) {
                out[index] += weights[index][indices] * aa[indices];
            }
        }
        return out;
    }
}

提前致谢!

4

0 回答 0