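I am trying to build a simple model that classifies points in 2D space into two partitions:
- I train the model by specifying a few points and the partition each one belongs to.
- I use the model to predict which group a test point is likely to fall into (classification).
Unfortunately, I am not getting the expected answer. Am I missing something in my code, or am I doing something wrong?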
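import java.util.HashMap;
import java.util.Map;

import org.apache.mahout.classifier.sgd.L1;
import org.apache.mahout.classifier.sgd.OnlineLogisticRegression;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.Vector;

public class SimpleClassifier {
    public static class Point {
        public int x;
        public int y;

        public Point(int x, int y) {
            this.x = x;
            this.y = y;
        }

        @Override
        public boolean equals(Object other) {
            // Guard against null and foreign types before casting.
            if (!(other instanceof Point)) {
                return false;
            }
            Point p = (Point) other;
            return (this.x == p.x) && (this.y == p.y);
        }

        // Kept consistent with equals, since Point is used as a HashMap key.
        @Override
        public int hashCode() {
            return 31 * x + y;
        }

        @Override
        public String toString() {
            return this.x + " , " + this.y;
        }
    }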
    public static void main(String[] args) {
        // Training data: points near the origin are class 0, points near (8..9, 8..9) are class 1.
        Map<Point, Integer> points = new HashMap<Point, Integer>();
        points.put(new Point(0, 0), 0);
        points.put(new Point(1, 1), 0);
        points.put(new Point(1, 0), 0);
        points.put(new Point(0, 1), 0);
        points.put(new Point(2, 2), 0);
        points.put(new Point(8, 8), 1);
        points.put(new Point(8, 9), 1);
        points.put(new Point(9, 8), 1);
        points.put(new Point(9, 9), 1);

        // 2 categories, 2 features, L1 prior.
        OnlineLogisticRegression learningAlgo = new OnlineLogisticRegression(2, 2, new L1());
        learningAlgo.learningRate(50);
        //learningAlgo.alpha(1).stepOffset(1000);

        System.out.println("training model \n");
        for (Point point : points.keySet()) {
            Vector v = getVector(point);
            System.out.println(point + " belongs to " + points.get(point));
            learningAlgo.train(points.get(point), v);
        }
        learningAlgo.close();

        // Now classify real data: a test point that sits among the class 0 training points.
        Vector v = new RandomAccessSparseVector(2);
        v.set(0, 0.5);
        v.set(1, 0.5);
        Vector r = learningAlgo.classifyFull(v);
        System.out.println(r);
        System.out.println("ans = ");
        System.out.println("no of categories = " + learningAlgo.numCategories());
        System.out.println("no of features = " + learningAlgo.numFeatures());
        System.out.println("Probability of cluster 0 = " + r.get(0));
        System.out.println("Probability of cluster 1 = " + r.get(1));
    }
    // Converts a Point into a 2-element feature vector (x, y).
    public static Vector getVector(Point point) {
        Vector v = new DenseVector(2);
        v.set(0, point.x);
        v.set(1, point.y);
        return v;
    }
}
Output:
ans =
no of categories = 2
no of features = 2
Probability of cluster 0 = 3.9580985042775296E-4
Probability of cluster 1 = 0.9996041901495722
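99% of the time the output shows cluster 1, even though the test point (0.5, 0.5) lies among the cluster 0 training points. Why?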
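One possible explanation, offered as a sketch rather than a confirmed diagnosis: the feature vectors built by getVector contain only x and y, with no constant (bias) input. Assuming the model adds no intercept on its own, its decision boundary must pass through the origin, and a boundary through the origin cannot put (1,1) and (8,8) on opposite sides. The plain-Java sketch below does not use Mahout and does not reproduce its update rule, learning-rate schedule, or L1 prior. It fits an ordinary logistic regression by batch gradient ascent on the same nine points, once without and once with a constant 1.0 feature. The class name BiasTermSketch, the helper names, and the learning rate and iteration count are all hypothetical choices made for this illustration.

```java
public class BiasTermSketch {

    // Same nine training points as in the question.
    static final double[][] POINTS = {
        {0, 0}, {1, 1}, {1, 0}, {0, 1}, {2, 2},
        {8, 8}, {8, 9}, {9, 8}, {9, 9}
    };
    static final int[] LABELS = {0, 0, 0, 0, 0, 1, 1, 1, 1};

    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    // Feature row for a point; the leading 1.0 is the optional bias input.
    static double[] features(double x, double y, boolean withBias) {
        return withBias ? new double[] {1.0, x, y} : new double[] {x, y};
    }

    // P(class 1) for an already-built feature row.
    static double probabilityOfClass1(double[] w, double[] f) {
        double z = 0.0;
        for (int j = 0; j < w.length; j++) {
            z += w[j] * f[j];
        }
        return sigmoid(z);
    }

    // Plain batch gradient ascent on the log-likelihood.
    // This is NOT Mahout's update rule, only a textbook stand-in.
    static double[] train(double[][] points, int[] labels, boolean withBias,
                          double learningRate, int iterations) {
        int n = withBias ? 3 : 2;
        double[] w = new double[n];
        for (int it = 0; it < iterations; it++) {
            double[] grad = new double[n];
            for (int i = 0; i < points.length; i++) {
                double[] f = features(points[i][0], points[i][1], withBias);
                double err = labels[i] - probabilityOfClass1(w, f);
                for (int j = 0; j < n; j++) {
                    grad[j] += err * f[j];
                }
            }
            for (int j = 0; j < n; j++) {
                w[j] += learningRate * grad[j];
            }
        }
        return w;
    }

    // P(class 1) for a raw point; a 3-element weight vector means a bias was used.
    static double predict(double[] w, double x, double y) {
        return probabilityOfClass1(w, features(x, y, w.length == 3));
    }

    public static void main(String[] args) {
        double[] noBias = train(POINTS, LABELS, false, 0.005, 50000);
        double[] withBias = train(POINTS, LABELS, true, 0.005, 50000);
        System.out.println("P(class 1 | 0.5, 0.5) without bias = " + predict(noBias, 0.5, 0.5));
        System.out.println("P(class 1 | 0.5, 0.5) with bias    = " + predict(withBias, 0.5, 0.5));
    }
}
```

In this sketch the model without the bias input assigns (0.5, 0.5) a class 1 probability above 0.5, while the model with the bias input assigns it a probability below 0.5. If the same reasoning applies to the Mahout code, the corresponding change would be a three-feature vector whose first element is always 1. The learning rate of 50 and the single pass over the training data may also contribute, but they are not modeled here.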