3

我已经在 J​​ava 中实现了带有梯度下降的逻辑回归。好像效果不太好(没有正确分类记录;y=1的概率很大。)不知道我的实现是否正确。代码我已经翻了好几遍了,还是不行找到任何错误。我一直在关注 Andrew Ng 在 Course Era 上的机器学习教程。我的 Java 实现有 3 个类。即:

  1. DataSet.java : 读取数据集
  2. Instance.java:有两个成员:1. double[] x 和 2. double label
  3. Logistic.java :这是使用梯度下降实现逻辑回归的主要类。

这是我的成本函数:

J(Θ) = (- 1/m ) [Σ m i=1 y (i) log( h Θ ( x (i) ) ) + (1 - y (i) ) log(1 - h Θ (x (一世)))]

对于上面的成本函数,这是我的梯度下降算法:

重复 (

Θ j := Θ j - α Σ m i=1 ( h Θ ( x (i) ) - y (i) ) x (i) j

(同时更新所有 Θ j

)

import java.io.FileNotFoundException;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

public class Logistic {

    /** the learning rate */
    private double alpha;

    /** the weight to learn */
    private double[] theta;

    /** the number of iterations */
    private int ITERATIONS = 3000;

    public Logistic(int n) {
        this.alpha = 0.0001;
        theta = new double[n];
    }

    private double sigmoid(double z) {
        return (1 / (1 + Math.exp(-z)));
    }

    public void train(List<Instance> instances) {

    double[] temp = new double[3];

    //Gradient Descent algorithm for minimizing theta
    for(int i=1;i<=ITERATIONS;i++)
    {
       for(int j=0;j<3;j++)
       {      
        temp[j]=theta[j] - (alpha * sum(j,instances));
       }

       //simulataneous updates of theta  
       for(int j=0;j<3;j++)
       {
         theta[j] = temp[j];
       }
        System.out.println(Arrays.toString(theta));
    }

    }

    private double sum(int j,List<Instance> instances)
    {
        double[] x;
        double prediction,sum=0,y;


       for(int i=0;i<instances.size();i++)
       {
          x = instances.get(i).getX();
          y = instances.get(i).getLabel();
          prediction = classify(x);
          sum+=((prediction - y) * x[j]);
       }
         return (sum/instances.size());

    }

    private double classify(double[] x) {
        double logit = .0;
        for (int i=0; i<theta.length;i++)  {
            logit += (theta[i] * x[i]);
        }
        return sigmoid(logit);
    }


    public static void main(String... args) throws FileNotFoundException {

      //DataSet is a class with a static method readDataSet which reads the dataset
      // Instance is a class with two members: double[] x, double label y
      // x contains the features and y is the label.

        List<Instance> instances = DataSet.readDataSet("data.txt");
      // 3 : number of theta parameters corresponding to the features x 
      // x0 is always 1   
        Logistic logistic = new Logistic(3);
        logistic.train(instances);

        //Test data
        double[]x = new double[3];
        x[0]=1;
        x[1]=45;
        x[2] = 85;

        System.out.println("Prob: "+logistic.classify(x));


    }
}

谁能告诉我我做错了什么?提前致谢!:)

4

1 回答 1

1

在研究逻辑回归时,我花时间详细查看了您的代码。

TLDR

事实上,看起来这个算法是正确的。

我认为,你有这么多假阴性或假阳性的原因是因为你选择的超参数。

该模型训练不足,因此假设欠拟合。

细节

我不得不创建DataSetInstance类,因为您没有发布它们,并基于 Cryotherapy 数据集设置了一个训练数据集和一个测试数据集。见http://archive.ics.uci.edu/ml/datasets/Cryotherapy+Dataset+

然后,使用您相同的精确代码(用于逻辑回归部分)并通过选择 alpha 率0.001和迭代次数100000,我80.64516129032258在测试数据集上获得了百分比的准确率,这还不错。

我试图通过手动调整这些超参数来获得更好的精确率,但无法获得更好的结果。

我想,在这一点上,一个增强将是实现正则化。

梯度下降公式

1/m在 Andrew Ng 关于成本函数和梯度下降的视频中,省略该术语是正确的。一种可能的解释是该1/m术语包含在该alpha术语中。或者也许这只是一个疏忽。请参阅https://www.youtube.com/watch?v=TTdcc21Ko9A&index=36&list=PLLssT5z_DsK-h9vYZkQkYNWcItqhlRJLN&t=6m53s 6m53s。

但是,如果您观看 Andrew Ng 关于正则化和逻辑回归的视频,您会注意到该术语1/m清楚地出现在公式中。请参阅https://www.youtube.com/watch?v=IXPgm1e0IOo&index=42&list=PLLssT5z_DsK-h9vYZkQkYNWcItqhlRJLN&t=2m19s 2m19s。

于 2019-01-24T09:44:44.390 回答