我已经在 Java 中实现了带有梯度下降的逻辑回归。好像效果不太好(没有正确分类记录;y=1的概率很大。)不知道我的实现是否正确。代码我已经翻了好几遍了,还是不行找到任何错误。我一直在关注 Andrew Ng 在 Course Era 上的机器学习教程。我的 Java 实现有 3 个类。即:
- DataSet.java : 读取数据集
- Instance.java:有两个成员:1. double[] x 和 2. double label
- Logistic.java :这是使用梯度下降实现逻辑回归的主要类。
这是我的成本函数:
J(Θ) = (- 1/m ) [Σ m i=1 y (i) log( h Θ ( x (i) ) ) + (1 - y (i) ) log(1 - h Θ (x (一世)))]
对于上面的成本函数,这是我的梯度下降算法:重复 (
Θ j := Θ j - α Σ m i=1 ( h Θ ( x (i) ) - y (i) ) x (i) j
(同时更新所有 Θ j))
import java.io.FileNotFoundException;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
public class Logistic {
/** the learning rate */
private double alpha;
/** the weight to learn */
private double[] theta;
/** the number of iterations */
private int ITERATIONS = 3000;
public Logistic(int n) {
this.alpha = 0.0001;
theta = new double[n];
}
private double sigmoid(double z) {
return (1 / (1 + Math.exp(-z)));
}
public void train(List<Instance> instances) {
double[] temp = new double[3];
//Gradient Descent algorithm for minimizing theta
for(int i=1;i<=ITERATIONS;i++)
{
for(int j=0;j<3;j++)
{
temp[j]=theta[j] - (alpha * sum(j,instances));
}
//simulataneous updates of theta
for(int j=0;j<3;j++)
{
theta[j] = temp[j];
}
System.out.println(Arrays.toString(theta));
}
}
private double sum(int j,List<Instance> instances)
{
double[] x;
double prediction,sum=0,y;
for(int i=0;i<instances.size();i++)
{
x = instances.get(i).getX();
y = instances.get(i).getLabel();
prediction = classify(x);
sum+=((prediction - y) * x[j]);
}
return (sum/instances.size());
}
private double classify(double[] x) {
double logit = .0;
for (int i=0; i<theta.length;i++) {
logit += (theta[i] * x[i]);
}
return sigmoid(logit);
}
public static void main(String... args) throws FileNotFoundException {
//DataSet is a class with a static method readDataSet which reads the dataset
// Instance is a class with two members: double[] x, double label y
// x contains the features and y is the label.
List<Instance> instances = DataSet.readDataSet("data.txt");
// 3 : number of theta parameters corresponding to the features x
// x0 is always 1
Logistic logistic = new Logistic(3);
logistic.train(instances);
//Test data
double[]x = new double[3];
x[0]=1;
x[1]=45;
x[2] = 85;
System.out.println("Prob: "+logistic.classify(x));
}
}
谁能告诉我我做错了什么?提前致谢!:)