55

我正在寻找 Java 中的 KDTree 实现。
我做了一个谷歌搜索,结果似乎很随意。实际上有很多结果,但它们大多只是一次性的实现,我宁愿找到更多“生产价值”的东西。诸如 apache 集合或 .NET 的优秀 C5 集合库之类的东西。我可以看到公共错误跟踪器并检查最后一次 SVN 提交发生的时间。此外,在理想情况下,我会为空间数据结构找到一个精心设计的 API,而 KDTree 只是该库中的一个类。

对于这个项目,我只会在 2 维或 3 维中工作,而且我主要只是对一个好的最近邻实现感兴趣。

4

10 回答 10

25

《Algorithms in a Nutshell》一书中有一个 Java 版的 kd 树实现以及一些变体。所有代码都可以在 oreilly.com 上找到,这本书也会带您逐步了解该算法,以便您自己实现一个。

于 2008-11-06T15:26:21.647 回答
17

对于未来的寻求者。Java-ml 库有一个运行良好的 kd-tree 实现。 http://java-ml.sourceforge.net/

于 2011-11-14T11:15:58.740 回答
11

我使用了在此处找到的 Levy 教授的实现,效果不错。我知道您想找的是更经得起生产环境检验的实现,因此这可能不太合适。

不过提醒各位路过的读者:我已经在我的照片马赛克项目中使用它一段时间了,没有出现任何问题。不作任何保证,但总比没有好 :)

于 2010-05-06T20:12:44.960 回答
4

我创建了一个 KD-Tree 实现作为离线反向地理编码库的一部分

https://github.com/AReallyGoodName/OfflineReverseGeocode

于 2014-06-13T12:45:29.000 回答
3

也许来自 Stony-Brook 算法存储库的最近邻搜索KD-trees可以提供帮助。

于 2008-11-03T21:39:33.673 回答
2

这是 KD-Tree 的完整实现,我使用了一些库来存储点和矩形。这些库是免费提供的。可以使用这些类创建自己的类来存储点和矩形。请分享您的反馈。

import java.util.ArrayList;
import java.util.List;
import edu.princeton.cs.algs4.In;
import edu.princeton.cs.algs4.Point2D;
import edu.princeton.cs.algs4.RectHV;
import edu.princeton.cs.algs4.StdDraw;
/**
 * A 2d-tree (k-d tree with k = 2) of {@link Point2D} keys supporting insertion,
 * membership tests, axis-aligned range search and nearest-neighbour search.
 *
 * <p>The root splits on the x coordinate (vertical line) and every deeper level
 * alternates between y and x. Each node stores the axis-aligned cell it
 * subdivides; the root cell is the unit square, so points are expected to lie
 * in [0, 1] x [0, 1].
 */
public class KdTree {

    /** A tree node: a point, the cell it subdivides and its two subtrees. */
    private static class Node {
        private final Point2D point; // the point stored in this node
        private final RectHV rect;   // the axis-aligned cell this node subdivides
        private final double x;      // cached point.x()
        private final double y;      // cached point.y()
        private Node lb;             // the left/bottom subtree
        private Node rt;             // the right/top subtree
        private int size = 1;        // number of nodes in the subtree rooted here

        Node(Point2D point, RectHV rect) {
            this.point = point;
            this.rect = rect;
            this.x = point.x();
            this.y = point.y();
        }
    }

    /** Mutable holder for the best candidate found so far by a nearest-neighbour search. */
    private static class Champion {
        private Point2D point;
        private double distSquared;

        Champion(Point2D point, double distSquared) {
            this.point = point;
            this.distSquared = distSquared;
        }
    }

    private Node root;

    /** Creates an empty tree. */
    public KdTree() {
    }

    /**
     * @return {@code true} if no point has been inserted yet
     */
    public boolean isEmpty() {
        return root == null;
    }

    /**
     * @return the number of distinct points stored in the tree
     */
    public int size() {
        return sizeOf(root);
    }

    /** Size of a possibly-null subtree. */
    private static int sizeOf(Node node) {
        return node == null ? 0 : node.size;
    }

    /** Even levels split on y (horizontal line), odd levels on x (vertical line); the root is level 1. */
    private static boolean splitsOnY(int level) {
        return level % 2 == 0;
    }

    /**
     * Adds a point to the tree; inserting a point that is already present is a no-op.
     *
     * @param p the point to insert
     * @throws NullPointerException if {@code p} is null
     */
    public void insert(Point2D p) {
        if (p == null) {
            throw new NullPointerException("point to insert is null");
        }
        if (root == null) {
            root = new Node(p, new RectHV(0, 0, 1, 1));
        } else {
            insertInternal(p, root, 1);
        }
    }

    /**
     * Descends from {@code node} to the place where {@code p} belongs and links a
     * new leaf there, refreshing subtree sizes on the way back up.
     */
    private void insertInternal(Point2D p, Node node, int level) {
        boolean onY = splitsOnY(level);
        boolean lower = onY ? p.y() < node.y : p.x() < node.x;
        if (lower) {
            if (node.lb == null) {
                // The cell is built BEFORE the leaf is linked: if RectHV rejects the
                // bounds (e.g. a point outside the root cell), the tree is left untouched
                // instead of holding a node without a rectangle.
                node.lb = new Node(p, childCell(node, onY, true));
            } else {
                insertInternal(p, node.lb, level + 1);
            }
        } else if (!node.point.equals(p)) {
            // Ties on the split coordinate go right/top; an identical point is skipped.
            if (node.rt == null) {
                node.rt = new Node(p, childCell(node, onY, false));
            } else {
                insertInternal(p, node.rt, level + 1);
            }
        }
        node.size = 1 + sizeOf(node.lb) + sizeOf(node.rt);
    }

    /** Returns the half of {@code parent}'s cell on the requested side of its splitting line. */
    private static RectHV childCell(Node parent, boolean onY, boolean lower) {
        RectHV cell = parent.rect;
        if (onY) {
            return lower
                    ? new RectHV(cell.xmin(), cell.ymin(), cell.xmax(), parent.y)
                    : new RectHV(cell.xmin(), parent.y, cell.xmax(), cell.ymax());
        }
        return lower
                ? new RectHV(cell.xmin(), cell.ymin(), parent.x, cell.ymax())
                : new RectHV(parent.x, cell.ymin(), cell.xmax(), cell.ymax());
    }

    /**
     * Tells whether the tree holds a point equal to {@code p}.
     *
     * @param p the point to look for
     * @return {@code true} if an equal point is stored
     * @throws NullPointerException if {@code p} is null
     */
    public boolean contains(Point2D p) {
        if (p == null) {
            throw new NullPointerException("point to search is null");
        }
        Node node = root;
        for (int level = 1; node != null; level++) {
            boolean lower = splitsOnY(level) ? p.y() < node.y : p.x() < node.x;
            if (lower) {
                node = node.lb;
            } else if (node.point.equals(p)) {
                return true;
            } else {
                node = node.rt;
            }
        }
        return false;
    }

    /** Clears the canvas and draws every point together with its splitting line. */
    public void draw() {
        StdDraw.clear();
        drawInternal(root, 1);
    }

    /** Draws the subtree rooted at {@code node}: points in black, x-splits in red, y-splits in blue. */
    private void drawInternal(Node node, int level) {
        if (node == null) {
            return;
        }
        StdDraw.setPenColor(StdDraw.BLACK);
        StdDraw.setPenRadius(0.02);
        node.point.draw();

        StdDraw.setPenRadius(0.01);
        if (splitsOnY(level)) {
            StdDraw.setPenColor(StdDraw.BLUE);
            StdDraw.line(node.rect.xmin(), node.y, node.rect.xmax(), node.y);
        } else {
            StdDraw.setPenColor(StdDraw.RED);
            StdDraw.line(node.x, node.rect.ymin(), node.x, node.rect.ymax());
        }
        drawInternal(node.lb, level + 1);
        drawInternal(node.rt, level + 1);
    }

    /**
     * Finds the stored points that lie inside the given rectangle.
     *
     * @param rect the query rectangle
     * @return the points contained in {@code rect}; empty if there are none
     * @throws NullPointerException if {@code rect} is null
     */
    public Iterable<Point2D> range(RectHV rect) {
        if (rect == null) {
            throw new NullPointerException("query rectangle is null");
        }
        List<Point2D> resultList = new ArrayList<Point2D>();
        rangeInternal(root, rect, resultList);
        return resultList;
    }

    /** Collects matches from the subtree at {@code node}, skipping cells that miss the query rectangle. */
    private void rangeInternal(Node node, RectHV rect, List<Point2D> resultList) {
        if (node == null || !node.rect.intersects(rect)) {
            return;
        }
        if (rect.contains(node.point)) {
            resultList.add(node.point);
        }
        rangeInternal(node.lb, rect, resultList);
        rangeInternal(node.rt, rect, resultList);
    }

    /**
     * Finds the stored point closest to {@code p}.
     *
     * @param p the query point
     * @return the nearest stored point, or {@code null} if the tree is empty
     * @throws NullPointerException if {@code p} is null
     */
    public Point2D nearest(Point2D p) {
        if (p == null) {
            throw new NullPointerException("query point is null");
        }
        if (root == null) {
            return null;
        }
        Champion champion = new Champion(root.point, Double.MAX_VALUE);
        nearestInternal(p, root, champion, 1);
        return champion.point;
    }

    /**
     * Updates {@code champion} with the closest point in the subtree at {@code node}.
     * The side of the splitting line that contains the target is searched first; the
     * other side is visited only if the line itself is closer than the current champion.
     */
    private void nearestInternal(Point2D target, Node node, Champion champion, int level) {
        if (node == null) {
            return;
        }
        double dist = target.distanceSquaredTo(node.point);
        if (dist < champion.distSquared) {
            champion.point = node.point;
            champion.distSquared = dist;
        }

        boolean onY = splitsOnY(level);
        boolean lower = onY ? target.y() < node.y : target.x() < node.x;
        Node near = lower ? node.lb : node.rt;
        Node far = lower ? node.rt : node.lb;

        nearestInternal(target, near, champion, level + 1);

        // Squared distance from the target to the splitting line; no point on the far
        // side can be closer than this, so it decides whether that side can be pruned.
        double toLine = onY ? target.y() - node.y : target.x() - node.x;
        if (toLine * toLine < champion.distSquared) {
            nearestInternal(target, far, champion, level + 1);
        }
    }

    /**
     * Demo: reads whitespace-separated x/y pairs from a file, builds the tree,
     * prints its size, draws it and prints a sample nearest-neighbour query.
     *
     * @param args optional: {@code args[0]} is the input file; when absent the
     *             original sample path is used
     */
    public static void main(String[] args) {
        String filename = args.length > 0
                ? args[0]
                : "/home/raman/Downloads/kdtree/circle100.txt";
        In in = new In(filename);
        KdTree kdTree = new KdTree();
        while (!in.isEmpty()) {
            double x = in.readDouble();
            double y = in.readDouble();
            kdTree.insert(new Point2D(x, y));
        }
        System.out.println(kdTree.size());
        kdTree.draw();
        System.out.println(kdTree.nearest(new Point2D(0.4, 0.5)));
        System.out.println(new Point2D(0.7, 0.4).distanceSquaredTo(new Point2D(0.9, 0.5)));
        System.out.println(new Point2D(0.7, 0.4).distanceSquaredTo(new Point2D(0.9, 0.4)));
    }
}
于 2017-03-09T21:40:53.453 回答
1

还有JTS 拓扑套件

KdTree 实现只提供范围搜索(没有最近邻)。

如果最近邻居是您的事,请查看STRtree

于 2014-11-05T18:15:17.183 回答
1

你是对的,没有多少网站有 Java 的 kd 实现!无论如何,kd 树基本上是一个二叉搜索树,通常每次都会为该维度计算中值。这是简单的 KDNode,就最近邻方法或完整实现而言,请查看此github项目。这是我能为你找到的最好的一个。希望这对您有所帮助。

private class KDNode {
    KDNode left;
    KDNode right;
    E val;
    int depth;
    private KDNode(E e, int depth){
    this.left = null;
    this.right = null;
    this.val = e;
    this.depth = depth;
}
于 2014-12-14T07:40:40.083 回答
0

可能有人会感兴趣。请参阅我用 Java 实现的 2D 树(KdTree 类,包含 nearest() 方法):

import edu.princeton.cs.algs4.Point2D;
import edu.princeton.cs.algs4.RectHV;
import edu.princeton.cs.algs4.StdDraw;

import java.util.ArrayList;
import java.util.List;

/**
 * A 2d-tree of {@link Point2D} keys with insertion, membership test, range
 * search and nearest-neighbour search.
 *
 * <p>The root (level 1) splits on x; levels alternate between x (odd) and y
 * (even). Every node stores the axis-aligned rectangle it subdivides; the root
 * rectangle is the unit square, so points are expected to lie in [0, 1] x [0, 1].
 */
public class KdTree {
    private Node root;
    private int size;

    /** A tree node: a point, the rectangle it subdivides and its two subtrees. */
    private static class Node {
        private final Point2D p;    // the point
        private final RectHV rect;  // the axis-aligned rectangle corresponding to this node
        private Node lb;            // the left/bottom subtree
        private Node rt;            // the right/top subtree

        Node(Point2D p, RectHV rect) {
            this.p = p;
            this.rect = rect;
        }
    }

    /** Creates an empty tree. */
    public KdTree() {
    }

    /**
     * @return {@code true} if the tree holds no points
     */
    public boolean isEmpty() {
        return size == 0;
    }

    /**
     * @return the number of distinct points in the tree
     */
    public int size() {
        return size;
    }

    /** Even levels split on y (horizontal line), odd levels on x (vertical line). */
    private static boolean splitsOnY(int level) {
        return level % 2 == 0;
    }

    /** Tells whether {@code p} belongs to the left/bottom side of {@code node}'s splitting line. */
    private static boolean isLowerSide(Node node, Point2D p, int level) {
        return splitsOnY(level) ? p.y() < node.p.y() : p.x() < node.p.x();
    }

    /**
     * Tells whether the tree holds a point equal to {@code p}.
     *
     * @param p the point to look for
     * @return {@code true} if an equal point is stored
     * @throws IllegalArgumentException if {@code p} is null
     */
    public boolean contains(Point2D p) {
        if (p == null) {
            throw new IllegalArgumentException("argument to contains() is null");
        }
        Node node = root;
        for (int level = 1; node != null; level++) {
            if (node.p.equals(p)) {
                return true;
            }
            node = isLowerSide(node, p, level) ? node.lb : node.rt;
        }
        return false;
    }

    /**
     * Adds a point to the tree; a point that is already present is ignored.
     *
     * @param p the point to insert
     * @throws IllegalArgumentException if {@code p} is null
     */
    public void insert(Point2D p) {
        if (p == null) {
            throw new IllegalArgumentException("calls insert() with a null point");
        }
        if (root == null) {
            root = new Node(p, new RectHV(0, 0, 1, 1));
            size++;
            return;
        }
        Node node = root;
        for (int level = 1; ; level++) {
            if (node.p.equals(p)) {
                return; // already stored
            }
            boolean lower = isLowerSide(node, p, level);
            Node child = lower ? node.lb : node.rt;
            if (child == null) {
                // Build the rectangle first so that a rejected rectangle leaves both the
                // tree and the size counter untouched.
                Node leaf = new Node(p, childRect(node, splitsOnY(level), lower));
                if (lower) {
                    node.lb = leaf;
                } else {
                    node.rt = leaf;
                }
                size++;
                return;
            }
            node = child;
        }
    }

    /** Returns the half of {@code parent}'s rectangle on the requested side of its splitting line. */
    private static RectHV childRect(Node parent, boolean onY, boolean lower) {
        RectHV r = parent.rect;
        if (onY) {
            return lower
                    ? new RectHV(r.xmin(), r.ymin(), r.xmax(), parent.p.y())
                    : new RectHV(r.xmin(), parent.p.y(), r.xmax(), r.ymax());
        }
        return lower
                ? new RectHV(r.xmin(), r.ymin(), parent.p.x(), r.ymax())
                : new RectHV(parent.p.x(), r.ymin(), r.xmax(), r.ymax());
    }

    /** Draws every point (black) with its splitting line (red for x-splits, blue for y-splits). */
    public void draw() {
        draw(root, 1);
    }

    private void draw(Node node, int level) {
        if (node == null) {
            return;
        }
        StdDraw.setPenColor(StdDraw.BLACK);
        StdDraw.setPenRadius(0.01);
        node.p.draw();
        StdDraw.setPenRadius();

        if (splitsOnY(level)) {
            StdDraw.setPenColor(StdDraw.BLUE);
            StdDraw.line(node.rect.xmin(), node.p.y(), node.rect.xmax(), node.p.y());
        } else {
            StdDraw.setPenColor(StdDraw.RED);
            StdDraw.line(node.p.x(), node.rect.ymin(), node.p.x(), node.rect.ymax());
        }
        draw(node.lb, level + 1);
        draw(node.rt, level + 1);
    }

    /**
     * Finds the stored points that lie inside the given rectangle.
     *
     * @param rect the query rectangle
     * @return the points contained in {@code rect}; empty if there are none
     * @throws IllegalArgumentException if {@code rect} is null
     */
    public Iterable<Point2D> range(RectHV rect) {
        if (rect == null) {
            throw new IllegalArgumentException("calls range() with a null rect");
        }
        List<Point2D> points = new ArrayList<>();
        range(root, rect, points);
        return points;
    }

    /** Collects matches below {@code node}, skipping subtrees whose rectangle misses the query. */
    private void range(Node node, RectHV rect, List<Point2D> points) {
        if (node == null || !node.rect.intersects(rect)) {
            return;
        }
        if (rect.contains(node.p)) {
            points.add(node.p);
        }
        range(node.lb, rect, points);
        range(node.rt, rect, points);
    }

    /**
     * Finds the stored point closest to {@code query}.
     *
     * @param query the query point
     * @return the nearest stored point, or {@code null} if the tree is empty
     * @throws IllegalArgumentException if {@code query} is null
     */
    public Point2D nearest(Point2D query) {
        // Validate before the emptiness shortcut so a null argument is always reported,
        // as it is by contains(), insert() and range().
        if (query == null) {
            throw new IllegalArgumentException("calls nearest() with a null point");
        }
        if (isEmpty()) {
            return null;
        }
        return nearest(root, query, root.p, 1);
    }

    /**
     * Returns the closer of {@code champ} and the best point in the subtree at {@code node}.
     * The subtree on the query's side of the splitting line is searched first; the other
     * one is searched only if its rectangle is closer to the query than the champion is.
     */
    private Point2D nearest(Node node, Point2D query, Point2D champ, int level) {
        if (node == null) {
            return champ;
        }
        double best = champ.distanceSquaredTo(query);
        if (best < node.rect.distanceSquaredTo(query)) {
            return champ; // nothing in this rectangle can beat the champion
        }
        if (node.p.distanceSquaredTo(query) < best) {
            champ = node.p;
        }

        boolean upperFirst = splitsOnY(level)
                ? node.p.y() < query.y()
                : node.p.x() < query.x();
        Node near = upperFirst ? node.rt : node.lb;
        Node far = upperFirst ? node.lb : node.rt;

        champ = nearest(near, query, champ, level + 1);
        if (far != null
                && far.rect.distanceSquaredTo(query) < champ.distanceSquaredTo(query)) {
            champ = nearest(far, query, champ, level + 1);
        }
        return champ;
    }

    /** Small smoke test: inserts five points, prints the nearest one to a query and draws the tree. */
    public static void main(String[] args) {
        KdTree kd = new KdTree();
        kd.insert(new Point2D(0.7, 0.2));
        kd.insert(new Point2D(0.5, 0.4));
        kd.insert(new Point2D(0.2, 0.3));
        kd.insert(new Point2D(0.4, 0.7));
        kd.insert(new Point2D(0.9, 0.6));

        Point2D query1 = new Point2D(0.972, 0.887);
        System.out.println(kd.nearest(query1));

        kd.draw();
    }
}
于 2022-02-28T16:36:04.663 回答
-1
package kdtree;

/** A node of a 2-dimensional k-d tree holding one integer point. */
class KDNode {
    KDNode left;   // subtree of points smaller on this node's split coordinate
    KDNode right;  // subtree of points greater than or equal on the split coordinate
    int[] data;    // the point as {x, y}; null for a node built with the no-arg constructor

    /** Creates an empty node without a point or children. */
    public KDNode() {
        left = null;
        right = null;
    }

    /**
     * Creates a leaf holding a private copy of the first two entries of {@code x},
     * so later changes to the caller's array do not affect the tree.
     *
     * @param x the point; at least two entries are read
     */
    public KDNode(int[] x) {
        left = null;
        right = null;
        data = new int[2];
        for (int k = 0; k < 2; k++) {
            data[k] = x[k];
        }
    }
}

/**
 * A minimal, unbalanced 2-dimensional k-d tree over integer points with
 * insertion, exact-match search and the three depth-first print traversals.
 *
 * <p>The split coordinate cycles with depth. During insertion a point that ties
 * with a node on the split coordinate is placed in that node's right subtree.
 */
class KDTreeImpl {
    KDNode root;
    int cd = 0;   // split coordinate used at the root by insert()
    int DIM = 2;  // number of coordinates the split index cycles through

    /** Creates an empty tree. */
    public KDTreeImpl() {
        root = null;
    }

    /**
     * @return {@code true} if nothing has been inserted yet
     */
    public boolean isEmpty() {
        return root == null;
    }

    /**
     * Inserts a point; the tree keeps its own copy of the coordinates.
     *
     * @param x the point as {x, y}
     */
    public void insert(int[] x) {
        root = insert(x, root, cd);
    }

    /** Inserts {@code x} below {@code t}, comparing on coordinate {@code cd} at this level. */
    private KDNode insert(int[] x, KDNode t, int cd) {
        if (t == null) {
            return new KDNode(x);
        }
        int next = (cd + 1) % DIM;
        if (x[cd] < t.data[cd]) {
            t.left = insert(x, t.left, next);
        } else {
            t.right = insert(x, t.right, next);
        }
        return t;
    }

    /**
     * Tells whether a point with the same two coordinates is stored in the tree.
     *
     * @param data the point to look for as {x, y}
     * @return {@code true} if an equal point is stored
     */
    public boolean search(int[] data) {
        return search(data, root, 0);
    }

    /** Searches below {@code t}, following the same left/right rule that insert uses. */
    private boolean search(int[] x, KDNode t, int cd) {
        if (t == null) {
            return false;
        }
        if (x[0] == t.data[0] && x[1] == t.data[1]) {
            return true;
        }
        int next = (cd + 1) % DIM;
        if (x[cd] < t.data[cd]) {
            return search(x, t.left, next);
        }
        // Equal or greater on the split coordinate: insert() stores such points on the
        // right, so a tie that is not an exact match must keep descending there.
        return search(x, t.right, next);
    }

    /** Prints all points in in-order (left, node, right) on one line. */
    public void inorder() {
        inorder(root);
    }

    private void inorder(KDNode r) {
        if (r != null) {
            inorder(r.left);
            System.out.print("(" + r.data[0] + "," + r.data[1] + ") ");
            inorder(r.right);
        }
    }

    /** Prints all points in pre-order (node, left, right) on one line. */
    public void preorder() {
        preorder(root);
    }

    private void preorder(KDNode r) {
        if (r != null) {
            System.out.print("(" + r.data[0] + "," + r.data[1] + ") ");
            preorder(r.left);
            preorder(r.right);
        }
    }

    /** Prints all points in post-order (left, right, node) on one line. */
    public void postorder() {
        postorder(root);
    }

    private void postorder(KDNode r) {
        if (r != null) {
            postorder(r.left);
            postorder(r.right);
            System.out.print("(" + r.data[0] + "," + r.data[1] + ") ");
        }
    }
}
/** Command-line demo for {@code KDTreeImpl}: builds a small tree, prints it three ways and runs one search. */
public class KDTree {

    /**
     * Inserts five sample points, prints the in-, pre- and post-order traversals
     * and finally prints whether the point (40,40) is found.
     *
     * @param args the command line arguments (unused)
     */
    public static void main(String[] args) {
        int[][] samples = {
            {30, 40},
            {5, 25},
            {10, 12},
            {70, 70},
            {50, 30},
        };

        KDTreeImpl kdt = new KDTreeImpl();
        for (int[] sample : samples) {
            kdt.insert(sample);
        }

        System.out.println("Input Elements");
        System.out.println("(30,40) (5,25) (10,12) (70,70) (50,30)\n\n");

        System.out.println("Printing KD Tree in Inorder");
        kdt.inorder();

        System.out.println("\nPrinting KD Tree in PreOder");
        kdt.preorder();

        System.out.println("\nPrinting KD Tree in PostOrder");
        kdt.postorder();

        System.out.println("\nsearching...............");
        int[] probe = {40, 40};
        System.out.println(kdt.search(probe));
    }
}
于 2016-12-31T17:25:09.840 回答