我在 java 中编写了一个非常简单的神经网络,它有 1 个神经元、2 个输入和 1 个偏差,试图对一条线左侧或右侧的点进行分类。问题是神经网络识别直线的斜率而不识别 y 截距(例如直线的函数可能是 y = m*x + c 并且 NN 识别 m 但不识别 c)。
我尝试使用偏差值 = 1,以使 NN 能够计算偏差的权重,偏差应该是 y 截距。但它没有。您会看到我对 Java 编程非常陌生。不幸的是,NN 也是新手。在那种情况下,我想我的问题在于对 NN 中偏差的基本方法的理解。
备注:在代码最后的输出行中,我希望在函数 y = 3*x + 5 的情况下出现以下数字: weight[0]=3 (which is the m) , weight[1]= 1(这是 y 的因数)和 weight[2]=5(这是 c)。重量[2] 总是错误的。
package nn2;
public class anfang_eng {
public static void main(String[] args)
{
double[][] points = new double[5][10000];
double[] weights = new double[3];
double[][] normpoints = new double[5][10000];
// create 1000 dots with desired result for training afterwards
points = createPoints();
// the before randomly created x and y values of the 1000 dots
// shall be normalized between 0 and 1
normpoints = normalize(points);
// create two random initial weights
weights = createinitialWeights();
// training function, calculation of three different weights
calculateWeights(normpoints, weights);
testnewPoints(weights);
}
// thats the function of the line, that seperates all dots in
// two groups: all the dots at the left side of the line and all the dots
// at the right side.
static double function(double x, double y)
{
double result;
result = 3*x - y + 5;
return result;
}
static double[][] createPoints()
{
// 1. step: lets create for training reasons some dots and calculate
// the result for each dot (result is either "1" or "-1").
// point[0] is x, point[1] is y, point[2] is bias and point[3] is
// result (left or right side of the function above
int x;
int y;
int quantity= 1000;
double[][] point = new double[5][quantity];
for (int i=0; i<quantity; i++)
{
x = (int) (2000 * Math.random()-1000);
y = (int) (2000 * Math.random()-1000);
point[0][i] = x;
point[1][i] = y;
// point[2] is our bias
point[2][i] = 1;
// all dots which are at the right side of the function above get
// result "1". otherwise "-1"
if ( function(x,y) > 0)
point[3][i] = 1;
else
point[3][i] =-1;
// point[3] contains the result
}
// in the variable point, there are e.g. 1000 or 5000 dots with x, y,
// bias and the result (1=left side and -1=right side)
return point;
}
// normalize x and y values between 0 and 1
static double[][] normalize(double[][]points)
{
int quantity = points[0].length;
double minpoint_x=1000;
double minpoint_y=1000;
double maxpoint_x=-1000;
double maxpoint_y=-1000;
double[][] normpoints = new double[5][quantity];
minpoint_x= points[0][0];
minpoint_y = points[1][0];
maxpoint_x = points[0][0];
maxpoint_y = points[1][0];
for (int i=0; i<quantity;i++)
{
if (points[0][i]<minpoint_x)
minpoint_x=points[0][i];
if (points[1][i]<minpoint_y)
minpoint_y=points[1][i];
if (points[0][i]>maxpoint_x)
maxpoint_x=points[0][i];
if (points[1][i]>maxpoint_y)
maxpoint_y=points[1][i];
}
for (int u=0; u<quantity; u++)
{
normpoints [0][u]= (points[0][u]-minpoint_x)/(maxpoint_x-minpoint_x);
normpoints [1][u]= (points[1][u]-minpoint_y)/(maxpoint_y-minpoint_y);
normpoints [2][u] = 1; //bias is always 1
normpoints [3][u] = points[3][u];
}
return normpoints;
}
static double[] createinitialWeights()
{
// creation of initial weights between -1 and 1
double[] weight = new double[3];
weight[0] = 2*Math.random()-1;
weight[1] = 2*Math.random()-1;
weight[2] = 2*Math.random()-1;
return weight;
}
static void calculateWeights(double[][] normpoints, double[] weight)
// new weight = weight + error * input * learning constant
// c is learning constant
{
double c = 0.01;
double error = 0;
double sumguess = 0;
double guess = 0;
int quantity = normpoints[0].length;
for (int i=0; i < quantity; i++)
{
// normpoint[0][i] stands for the factor at x, normpoint[0][i] is
// for y and normpoint[2][i] is for bias
sumguess = normpoints[0][i] * weight[0] + normpoints[1][i]*weight[1] + normpoints[2][i]*weight[2];
if (sumguess > 0)
guess = 1;
else
guess = -1;
error = normpoints[3][i]- guess;
weight[0] = weight[0] + error * normpoints[0][i] * c;
weight[1] = weight[1] + error * normpoints[1][i] * c;
weight[2] = weight[2] + error * normpoints[2][i] * c;
System.out.println("i: " + i + " ;value_normpoint[0]:" + normpoints[0][i]+ " ;value_normpoint[1]" + normpoints[1][i]+ " ;value_normpoint[2]" + normpoints[2][i] + " result:" + normpoints[3][i]);
System.out.println("weight[0]: " + Math.round(weight[0]*100)/100.0 + " ;weight[1]: " +Math.round(weight[1]*100)/100.0 + " ;weight[2]: " + Math.round(weight[2]*100)/100.0 );
System.out.println("guess: "+ guess+ " result " + normpoints[3][i] + " error: " + error);
System.out.println();
}
System.out.println("final weights: x: " + weight[0] + " y: "+ weight[1] + " bias: " +weight[2]);
System.out.println("final weights normalized on y=1: x:" + weight[0]/weight[1] + " y: "+ weight[1]/weight[1] + " bias: " +weight[2]/weight[1]);
}
// lets test if the trained weights classify the test dot on the correct side of the line y=4*x+3
// again 500 random dots with "x", "y" and "results" are created and tested if the NN calculated correct weights
static void testnewPoints(double[] weights)
{
int x;
int y;
double[][] testpoint = new double[5][10000];
double[][] normalizedtestpoint = new double[5][10000];
int quantity = 500;
double sumcheck = 0;
double sumtest = 0;
int correct = 0;
int wrong = 0;
for (int i=0; i<quantity; i++)
{
// calculation of test points with x and y between -100 and 100
x = (int) (200 * Math.random()-100);
y = (int) (200 * Math.random()-100);
testpoint[0][i] = x;
testpoint[1][i] = y;
testpoint[2][i] = 1;
// lets classify the points: at the rights side of the line the result for each point is "1", on the left side "-1"
if (function(x,y) > 0)
testpoint[3][i] = 1;
else
testpoint[3][i] = -1;
// punkt[3] is the result
}
normalizedtestpoint= normalize(testpoint);
// are the test points with our calculated weights classified on the correct side of the line?
for (int i=0; i<quantity; i++)
{
sumcheck = normalizedtestpoint[0][i] * weights[0] + normalizedtestpoint[1][i] * weights[1] + normalizedtestpoint[2][i] * weights[2];
if (sumcheck > 0)
sumtest = 1;
else
sumtest = -1;
if (sumtest == normalizedtestpoint[3][i])
correct++;
else
wrong++;
}
System.out.println("correct: "+ correct + " wrong: " + wrong);
}
}
如果您在我的编码风格中发现一些重大问题,请告诉我,我猜这是一种初学者风格。提前谢谢了!隆科