我正在尝试实现与DBSCAN一起使用的KD 树。问题是我需要找到满足距离标准的所有点的所有邻居。问题是当我在我的实现中使用该方法时,我在使用朴素搜索(这是所需的输出)时没有得到相同的输出。我的实现改编自python 实现。这是我到目前为止所得到的:nearestNeighbours
//Point.java
package dbscan_gui;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
public class Point {
final HashSet<Point> neighbours = new HashSet<Point>();
int[] points;
boolean visited = false;
public Point(int... is) {
this.points = is;
}
public String toString() {
return Arrays.toString(points);
}
public double squareDistance(Point p) {
double sum = 0;
for (int i = 0;i < points.length;i++) {
sum += Math.pow(points[i] - p.points[i],2);
}
return sum;
}
public double distance(Point p) {
return Math.sqrt(squareDistance(p));
}
public void addNeighbours(ArrayList<Point> ps) {
neighbours.addAll(ps);
}
public void addNeighbour(Point p) {
if (p != this)
neighbours.add(p);
}
}
//KDTree.java
package dbscan_gui;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.TreeSet;
public class KDTree {
KDTreeNode root;
PointComparator[] comps;
public KDTree(ArrayList<Point> list) {
int axes = list.get(0).points.length;
comps = new PointComparator[axes];
for (int i = 0; i < axes; i++) {
comps[i] = new PointComparator(i);
}
root = new KDTreeNode(list,0);
}
private class PointComparator implements Comparator<Point> {
private int axis;
public PointComparator(int axis) {
this.axis = axis;
}
@Override
public int compare(Point p1, Point p2) {
return p1.points[axis] - p2.points[axis];
}
}
/**
* Adapted from https://code.google.com/p/python-kdtree/
* Stores points in a tree, sorted by axis
*/
public class KDTreeNode {
KDTreeNode leftChild = null;
KDTreeNode rightChild = null;
Point location;
public KDTreeNode(ArrayList<Point> list, int depth) {
if(list.isEmpty())
return;
final int axis = depth % (list.get(0).points.length);
Collections.sort(list, comps[axis] );
int median = list.size()/2;
location = list.get(median);
List<Point> leftPoints = list.subList(0, median);
List<Point> rightPoints = list.subList(median+1, list.size());
if(!leftPoints.isEmpty())
leftChild = new KDTreeNode(new ArrayList<Point>(leftPoints), depth+1);
if(!rightPoints.isEmpty())
rightChild = new KDTreeNode(new ArrayList<Point>(rightPoints),depth+1);
}
/**
* @return true if this node has no children
*/
public boolean isLeaf() {
return leftChild == null && rightChild == null;
}
}
/**
* Finds the nearest neighbours of a point that fall within a given distance
* @param queryPoint the point to find the neighbours of
* @param epsilon the distance threshold
* @return the list of points
*/
public ArrayList<Point> nearestNeighbours(Point queryPoint, int epsilon) {
KDNeighbours neighbours = new KDNeighbours(queryPoint);
nearestNeighbours_(root, queryPoint, 0, neighbours);
return neighbours.getBest(epsilon);
}
/**
* @param node
* @param queryPoint
* @param depth
* @param bestNeighbours
*/
private void nearestNeighbours_(KDTreeNode node, Point queryPoint, int depth, KDNeighbours bestNeighbours) {
if(node == null)
return;
if(node.isLeaf()) {
bestNeighbours.add(node.location);
return;
}
int axis = depth % (queryPoint.points.length);
KDTreeNode nearSubtree = node.rightChild;
KDTreeNode farSubtree = node.leftChild;
if(queryPoint.points[axis] < node.location.points[axis]) {
nearSubtree = node.leftChild;
farSubtree = node.rightChild;
}
nearestNeighbours_(nearSubtree, queryPoint, depth+1, bestNeighbours);
if(node.location != queryPoint)
bestNeighbours.add(node.location);
if(Math.pow(node.location.points[axis] - queryPoint.points[axis],2) <= bestNeighbours.largestDistance)
nearestNeighbours_(farSubtree, queryPoint, depth+1,bestNeighbours);
return;
}
/**
* Private datastructure for holding the neighbours of a point
*/
private class KDNeighbours {
Point queryPoint;
double largetsDistance = 0;
TreeSet<Tuple> currentBest = new TreeSet<Tuple>(new Comparator<Tuple>() {
@Override
public int compare(Tuple o1, Tuple o2) {
return (int) (o1.y-o2.y);
}
});
KDNeighbours(Point queryPoint) {
this.queryPoint = queryPoint;
}
public ArrayList<Point> getBest(int epsilon) {
ArrayList<Point> best = new ArrayList<Point>();
Iterator<Tuple> it = currentBest.iterator();
while(it.hasNext()) {
Tuple t =it.next();
if(t.y > epsilon*epsilon)
break;
else if(t.x != queryPoint)
best.add(t.x);
}
return best;
}
public void add(Point p) {
currentBest.add(new Tuple(p, p.squareDistance(queryPoint)));
largestDistance = currentBest.last().y;
}
private class Tuple {
Point x;
double y;
Tuple(Point x, double y) {
this.x = x;
this.y = y;
}
}
}
public static void main(String[] args) {
int epsilon = 3;
System.out.println("Epsilon: "+epsilon);
ArrayList<Point> points = new ArrayList<Point>();
Random r = new Random();
for (int i = 0; i < 10; i++) {
points.add(new Point(r.nextInt(10), r.nextInt(10)));
}
System.out.println("Points "+points );
System.out.println("----------------");
System.out.println("Neighbouring Kd");
KDTree tree = new KDTree(points);
for (Point p : points) {
ArrayList<Point> neighbours = tree.nearestNeighbours(p, epsilon);
for (Point q : neighbours) {
q.addNeighbour(p);
}
p.addNeighbours(neighbours);
p.printNeighbours();
p.neighbours.clear();
}
System.out.println("------------------");
System.out.println("Neighbouring O(n^2)");
for (int i = 0; i < points.size(); i++) {
for (int j = i + 1; j < points.size(); j++) {
Point p = points.get(i), q = points.get(j);
if (p.distance(q) <= epsilon) {
p.addNeighbour(q);
q.addNeighbour(p);
}
}
}
for (Point point : points) {
point.printNeighbours();
}
}
}
当我运行它时,我得到以下输出(后半部分是模型输出):
Epsilon: 3
Points [[9, 5], [4, 7], [3, 1], [0, 0], [5, 7], [0, 1], [5, 5], [1, 2], [9, 2], [9, 9]]
----------------
Neighbouring Kd
Neighbours of [0, 0] are: [[0, 1]]
Neighbours of [0, 1] are: [[1, 2], [0, 0], [3, 1]]
Neighbours of [1, 2] are: [[0, 1], [3, 1]]
Neighbours of [3, 1] are: [[0, 1], [1, 2]]
Neighbours of [4, 7] are: [[5, 7]]
Neighbours of [5, 7] are: [[4, 7]]
Neighbours of [5, 5] are: [[4, 7], [5, 7]]
Neighbours of [9, 5] are: [[9, 2]]
Neighbours of [9, 2] are: [[9, 5]]
Neighbours of [9, 9] are: []
------------------
Neighbouring O(n^2)
Neighbours of [0, 0] are: [[0, 1], [1, 2]]
Neighbours of [0, 1] are: [[1, 2], [0, 0], [3, 1]]
Neighbours of [1, 2] are: [[0, 1], [0, 0], [3, 1]]
Neighbours of [3, 1] are: [[0, 1], [1, 2]]
Neighbours of [4, 7] are: [[5, 5], [5, 7]]
Neighbours of [5, 7] are: [[4, 7], [5, 5]]
Neighbours of [5, 5] are: [[4, 7], [5, 7]]
Neighbours of [9, 5] are: [[9, 2]]
Neighbours of [9, 2] are: [[9, 5]]
Neighbours of [9, 9] are: []
我想不通为什么邻居不一样,似乎可以发现a->b是邻居,但不是b->a也是邻居。