25

我正在尝试在 Java 中实现 Dijkstra 的算法(自学)。我使用维基百科提供的伪代码(链接)。现在接近算法的结尾,我应该decrease-key v in Q;。我想我应该用 BinaryHeap 或类似的东西实现 Q ?在这里使用的正确(内置)数据类型是什么?

private void dijkstra(int source) {
        int[] dist = new int[this.adjacencyMatrix.length];
        int[] previous = new int[this.adjacencyMatrix.length];
        Queue<Integer> q = new LinkedList<Integer>();

        for (int i = 0; i < this.adjacencyMatrix.length; i++) {
            dist[i] = this.INFINITY;
            previous[i] = this.UNDEFINED;
            q.add(i);
        }

        dist[source] = 0;

        while(!q.isEmpty()) {
            // get node with smallest dist;
            int u = 0;
            for(int i = 0; i < this.adjacencyMatrix.length; i++) {
                if(dist[i] < dist[u])
                    u = i;
            }

            // break if dist == INFINITY
            if(dist[u] == this.INFINITY) break;

            // remove u from q
            q.remove(u);

            for(int i = 0; i < this.adjacencyMatrix.length; i++) {
                if(this.adjacencyMatrix[u][i] == 1) {
                    // in a unweighted graph, this.adjacencyMatrix[u][i] always == 1;
                    int alt = dist[u] + this.adjacencyMatrix[u][i]; 
                    if(alt < dist[i]) {
                        dist[i] = alt;
                        previous[i] = u;

                        // here's where I should "decrease the key"
                    }
                }
            }
        }
    }
4

5 回答 5

39

最简单的方法是使用优先级队列,而不关心优先级队列中之前添加的键。这意味着您将在队列中多次使用每个节点,但这根本不会损害算法。如果你看一下,所有被替换的节点版本将在稍后被拾取,到那时已经确定了最近的距离。

来自维基百科的检查if alt < dist[v]:是这项工作的原因。运行时只会因此而降级一点,但如果您需要非常快的版本,则必须进一步优化。

笔记:

像任何优化一样,这个优化应该小心处理,并且可能会导致好奇和难以发现的错误(参见例如here)。在大多数情况下,只使用 remove 和 re-insert 应该没问题,但是我在这里提到的技巧可以稍微加快你的代码,如果你的 Dijkstra 实现是瓶颈。

最重要的是:在尝试之前,请确保您的优先级队列如何处理优先级。队列中的实际优先级永远不会改变,否则您可能会弄乱队列的不变量,这意味着存储在队列中的项目可能无法再检索。例如,在 Java 中,优先级与对象一起存储,因此您确实需要一个额外的包装器:

这将不起作用:

import java.util.PriorityQueue;

// Store node information and priority together
class Node implements Comparable<Node> {
  public int prio;
  public Node(int prio) { this.prio = prio; }

  public int compareTo(Node n) {
     return Integer.compare(this.prio, n.prio);
  }
}

...
...
PriorityQueue<Node> q = new PriorityQueue<Node>();
n = new Node(10);
q.add(n)
...
// let's update the priority
n.prio = 0;
// re-add
q.add(n);
// q may be broken now

因为n.prio=0您也在更改队列中对象的优先级。但是,这将正常工作:

import java.util.PriorityQueue;

// Only node information
class Node {
  // Whatever you need for your graph
  public Node() {}
}

class PrioNode {
   public Node n;
   public int prio;
   public PrioNode(Node n, int prio) {
     this.n = n;
     this.prio = prio;
   }

   public int compareTo(PrioNode p) {
      return Integer.compare(this.prio, p.prio);
   }
}

...
...
PriorityQueue<PrioNode> q = new PriorityQueue<PrioNode>();
n = new Node();
q.add(new PrioNode(n,10));
...
// let's update the priority and re-add
q.add(new PrioNode(n,0));
// Everything is fine, because we have not changed the value
// in the queue.
于 2011-06-07T15:28:08.317 回答
22

您可以使用TreeSet, (在 C++ 中您可以使用 a std::set)来实现 Dijkstra 的优先级队列。ATreeSet代表一个集合,但我们也允许描述集合中项目的顺序。您需要将节点存储在集合中并使用节点的距离对节点进行排序。距离最小的节点将位于集合的开头。

class Node {
    public int id;   // store unique id to distinguish elements in the set
    public int dist; // store distance estimates in the Node itself
    public int compareTo(Node other) {
        // TreeSet implements the Comparable interface.
        // We need tell Java how to compare two Nodes in the TreeSet.
        // For Dijstra, we want the node with the _smallest_ distance
        // to be at the front of the set.
        // If two nodes have same distance, choose node with smaller id
        if (this.dist == other.dist) {
            return Integer.valueOf(this.id).compareTo(other.id);
        } else {
            return Integer.valueOf(this.dist).compareTo(other.dist);
        }
    }
}
// ...
TreeSet<Node> nodes = new TreeSet<Node>();

extract-min操作通过以下方式实现,最坏情况时间为 O(lgn):

Node u = nodes.pollFirst();

通过减少键操作,我们删除具有旧键的节点(旧的距离估计)并添加具有较小键的新节点(新的、更好的距离估计)。两种操作都需要 O(lgn) 最坏情况时间。

nodes.remove(v);
v.dist = /* some smaller key */
nodes.add(v);

一些额外的说明:

  • 上面的实现很简单,因为这两个操作都是n的对数,总的来说,运行时间将是O((n + e)lgn)。这对于 Dijkstra 的基本实现被认为是有效的。有关这种复杂性的证明,请参阅CLRS 书ISBN:978-0-262-03384-8 )第 19 章。
  • 尽管大多数教科书都会对 Dijkstra、Prim、A* 等使用优先级队列,但不幸的是,Java 和 C++ 实际上都没有实现具有相同 O(lgn) 减少键操作的优先级队列!

  • PriorityQueueJava 中确实存在,但该remove(Object o)方法不是对数的,因此您的减少键操作将是 O(n) 而不是 O(lgn) 并且(渐近地)您会得到一个较慢的 Dikjstra!

  • 要从无到有(使用 for 循环)构建 TreeSet,需要时间 O(nlgn),与 O(n) 最坏情况时间相比,从 n 个项目初始化堆/优先级队列要慢一些。然而,Dijkstra 的主循环需要时间 O(nlgn + elgn),这在初始化时间中占主导地位。所以对于 Dijkstra 来说,初始化 TreeSet 不会导致任何明显的减速。

  • 我们不能使用 aHashSet因为我们关心键的顺序——我们希望能够先取出最小的。这为我们提供了具有最佳距离估计的节点!

  • TreeSet在 Java 中是使用红黑树实现的——一种自平衡二叉搜索树。这就是为什么这些操作具有对数最坏情况时间。

  • 您正在使用ints 来表示您的图形节点,这很好,但是当您引入一个Node类时,您需要一种方法来关联这两个实体。我建议HashMap<Integer, Node>您在构建图表时构建一个 - 这将有助于跟踪哪个int对应于什么Node。`

于 2017-02-24T16:10:15.177 回答
3

建议PriorityQueue不提供减键操作。但是,可以通过删除元素然后用新键重新插入来模拟它。这不应该增加算法的渐近运行时间,尽管它可以通过内置支持稍微快一点。

编辑:这确实增加了渐近运行时间,因为 reduce-key 预计O(log n)用于堆remove(Object)O(n). Java 中似乎没有任何内置的优先级队列支持有效的减少键。

于 2011-06-07T15:08:00.670 回答
1

根据 wiki 文章的优先级队列。这表明现在的经典实现是使用“由斐波那契堆实现的最小优先级队列”。

于 2011-06-07T14:59:44.107 回答
1

是的,Java 没有通过 PriorityQueue 为 Min Heap 提供减少键,因此删除操作将是 O(N),可以优化为 logN。

我已经用 reductionKey 实现了 Min Heap(实际上是 reductionKey 和 increaseKey 但这里只有 reductionKey 就足够了)。实际数据结构是最小堆映射(HashMap 存储所有节点的索引,并有助于通过当前顶点更新当前顶点的邻居的最小路径值)

我用泛型实现了优化的解决方案,我花了大约 3-4 个小时来编码(我的第一次),时间复杂度是 O(logV.E)

希望这会有所帮助!

 package algo;

 import java.util.*;

 public class Dijkstra {

/*
 *
 * @author nalin.sharma
 *
 */

/**
 *
 * Heap Map Data Structure
 * Heap stores the Nodes with their weight based on comparison on Node's weight
 * HashMap stores the Node and its index for O(1) look up of node in heap
 *
 *
 *
 *
 * Example -:
 *
 *                                   7
 *                         [2]----------------->[4]
 *                       ^  |                   ^ \
 *                     /   |                   |   \ 1
 *                 2 /    |                   |     v
 *                 /     |                   |       [6]
 *               /      | 1               2 |       ^
 *             /       |                   |      /
 *          [1]       |                   |     /
 *             \     |                   |    / 5
 *            4 \   |                   |   /
 *               v v                   |  /
 *                [3]---------------->[5]
 *                         3
 *
 *        Minimum distance from source 1
 *         v  | d[v] | path
 *         --- ------  ---------
 *         2 |  2  |  1,2
 *         3 |  3  |  1,2,3
 *         4 |  8  |  1,2,3,5,4
 *         5 |  6  |  1,2,3,5
 *         6 |  9  |  1,2,3,4,6
 *
 *
 *
 *     Below is the Implementation -:
 *
 */

static class HeapMap<T> {
    int size, ind = 0;
    NodeWeight<T> arr [];
    Map<T,Integer> map;

    /**
     *
     * @param t is the Node(1,2,3..or A,B,C.. )
     * @return the index of element in min heap
     */
    int index(T t) {
        return map.get(t);
    }

    /**
     *
     * @param index is the Node(1,2,3..or A,B,C.. )
     * @return Node and its Weight
     */
    NodeWeight<T> node(int index) {
        return arr[index];
    }

    /**
     *
     * @param <T> Node of type <T> and its weight
     */
    static class NodeWeight<T> {
        NodeWeight(T v, int w) {
            nodeVal = v;
            weight = w;
        }
        T nodeVal;
        int weight;
        List<T> path = new ArrayList<>();
    }

    public HeapMap(int s) {
        size = s;
        arr = new NodeWeight[size + 1];
        map = new HashMap<>();
    }

    private void updateIndex(T key, int newInd) {
        map.put(key, newInd);
    }

    private void shiftUp(int i) {
        while(i > 1) {
            int par = i / 2;
            NodeWeight<T> currNodeWeight = arr[i];
            NodeWeight<T> parentNodeWeight = arr[par];
            if(parentNodeWeight.weight > currNodeWeight.weight) {
                updateIndex(parentNodeWeight.nodeVal, i);
                updateIndex(currNodeWeight.nodeVal, par);
                swap(par,i);
                i = i/2;
            }
            else {
                break;
            }
        }
    }

    /**
     *
     * @param nodeVal
     * @param newWeight
     * Based on if the value introduced is higher or lower shift down or shift up operations are performed
     *
     */
    public void update(T nodeVal, int newWeight) {
        int i = ind;
        NodeWeight<T> nodeWeight = arr[map.get(nodeVal)];
        int oldWt = nodeWeight.weight;
        nodeWeight.weight = newWeight;
        if(oldWt < newWeight) {
            shiftDown(map.get(nodeVal));
        }
        else if(oldWt > newWeight) {
            shiftUp(map.get(nodeVal));
        }
    }

    /**
     *
     * @param nodeVal
     * @param wt
     *
     * Typical insertion in Min Heap and storing its element's indices in HashMap for fast lookup
     */
    public void insert(T nodeVal, int wt) {
        NodeWeight<T> nodeWt = new NodeWeight<>(nodeVal, wt);
        arr[++ind] = nodeWt;
        updateIndex(nodeVal, ind);
        shiftUp(ind);
    }

    private void swap(int i, int j) {
        NodeWeight<T> tmp = arr[i];
        arr[i] = arr[j];
        arr[j] = tmp;
    }

    private void shiftDown(int i) {
        while(i <= ind) {
            int current = i;
            int lChild = i * 2;
            int rChild = i * 2 + 1;
            if(rChild <= ind) {
                int childInd = (arr[lChild].weight < arr[rChild].weight) ? lChild : rChild;
                if(arr[childInd].weight < arr[i].weight) {
                    updateIndex(arr[childInd].nodeVal, i);
                    updateIndex(arr[i].nodeVal, childInd);
                    swap(childInd, i);
                    i = childInd;
                }
            }
            else if(lChild <= ind && arr[lChild].weight < arr[i].weight) {
                updateIndex(arr[lChild].nodeVal, i);
                updateIndex(arr[i].nodeVal, lChild);
                swap(lChild, i);
                i = lChild;
            }
            if(current == i) {
                break;
            }
        }
    }

    /**
     *
     * @return
     *
     * Typical deletion in Min Heap and removing its element's indices in HashMap
     *
     */
    public NodeWeight<T> remove() {
        if(ind == 0) {
            return null;
        }
        map.remove(arr[1].nodeVal);
        NodeWeight<T> out = arr[1];
        out.path.add(arr[1].nodeVal);
        arr[1] = arr[ind];
        arr[ind--] = null;
        if(ind > 0) {
            updateIndex(arr[1].nodeVal, 1);
            shiftDown(1);
        }
        return out;
    }
}

/**
 *
 *  Graph representation -: It is an Map(T,Node<T>) of Map(T(neighbour), Integer(Edge's weight))
 *
 */
static class Graph<T> {

    void init(T... t) {
        for(T z: t) {
            nodes.put(z, new Node<>(z));
        }
    }

    public Graph(int s, T... t) {
        size = s;
        nodes = new LinkedHashMap<>(size);
        init(t);
    }

    /**
     *
     * Node class
     *
     */
    static class Node<T> {
        Node(T v) {
            val = v;
        }
        T val;
        //Set<Edge> edges = new HashSet<>();
        Map<T, Integer> edges = new HashMap<>();
    }

    /*static class Edge {
        Edge(int to, int w) {
            target = to;
            wt = w;
        }
        int target;
        int wt;
        }
    }*/

    int size;

    Map<T, Node<T>> nodes;

    void addEdge(T from, T to, int wt) {
        nodes.get(from).edges.put(to, wt);
    }
}

/**
 *
 * @param graph
 * @param from
 * @param heapMap
 * @param <T>
 *
 * Performs initialisation of all the nodes from the start vertex
 *
 */
    private static <T> void init(Graph<T> graph, T from, HeapMap<T> heapMap) {
    Graph.Node<T> fromNode = graph.nodes.get(from);
    graph.nodes.forEach((k,v)-> {
            if(from != k) {
                heapMap.insert(k, fromNode.edges.getOrDefault(k, Integer.MAX_VALUE));
            }
        });
    }


static class NodePathMinWeight<T> {
    NodePathMinWeight(T n, List<T> p, int c) {
        node = n;
        path = p;
        minCost= c;
    }
    T node;
    List<T> path;
    int minCost;
}

/**
 *
 * @param graph
 * @param from
 * @param <T>
 * @return
 *
 * Repeat the below process for all the vertices-:
 * Greedy way of picking the current shortest distance and updating its neighbors distance via this vertex
 *
 * Since Each Vertex V has E edges, the time Complexity is
 *
 * O(V.logV.E)
 * 1. selecting vertex with shortest distance from source in logV time -> O(logV) via Heap Map Data structure
 * 2. Visiting all E edges of this vertex and updating the path of its neighbors if found less via this this vertex. -> O(E)
 * 3. Doing operation step 1 and step 2 for all the vertices -> O(V)
 *
 */

    static <T> Map<T,NodePathMinWeight<T>> dijkstra(Graph<T> graph, T from) {
    Map<T,NodePathMinWeight<T>> output = new HashMap<>();
    HeapMap<T> heapMap = new HeapMap<>(graph.nodes.size());
    init(graph, from, heapMap);
    Set<T> isNotVisited = new HashSet<>();
    graph.nodes.forEach((k,v) -> isNotVisited.add(k));
    isNotVisited.remove(from);
    while(!isNotVisited.isEmpty()) {
        HeapMap.NodeWeight<T> currNodeWeight = heapMap.remove();
        output.put(currNodeWeight.nodeVal, new NodePathMinWeight<>(currNodeWeight.nodeVal, currNodeWeight.path, currNodeWeight.weight));
        //mark visited
        isNotVisited.remove(currNodeWeight.nodeVal);
        //neighbors
        Map<T, Integer> neighbors = graph.nodes.get(currNodeWeight.nodeVal).edges;
        neighbors.forEach((k,v) -> {
            int ind = heapMap.index(k);
            HeapMap.NodeWeight<T> neighbor = heapMap.node(ind);
            int neighborDist = neighbor.weight;
            int currentDistance = currNodeWeight.weight;
            if(currentDistance + v < neighborDist) {
                //update
                neighbor.path = new ArrayList<>(currNodeWeight.path);
                heapMap.update(neighbor.nodeVal, currentDistance + v);
            }
        });
    }
    return output;
}

public static void main(String[] args) {
    Graph<Integer> graph = new Graph<>(6,1,2,3,4,5,6);
    graph.addEdge(1,2,2);
    graph.addEdge(1,3,4);
    graph.addEdge(2,3,1);
    graph.addEdge(2,4,7);
    graph.addEdge(3,5,3);
    graph.addEdge(5,6,5);
    graph.addEdge(4,6,1);
    graph.addEdge(5,4,2);

    Integer source = 1;
    Map<Integer,NodePathMinWeight<Integer>> map = dijkstra(graph,source);
    map.forEach((k,v) -> {
        v.path.add(0,source);
        System.out.println("source vertex :["+source+"] to vertex :["+k+"] cost:"+v.minCost+" shortest path :"+v.path);
    });
}

}

输出-:

源顶点:[1] 到顶点:[2] 成本:2 最短路径:[1, 2]

源顶点:[1] 到顶点:[3] 成本:3 最短路径:[1, 2, 3]

源顶点:[1] 到顶点:[4] 成本:8 最短路径:[1, 2, 3, 5, 4]

源顶点:[1] 到顶点:[5] 成本:6 最短路径:[1, 2, 3, 5]

源顶点:[1] 到顶点:[6] 成本:9 最短路径:[1, 2, 3, 5, 4, 6]

于 2021-05-26T14:01:50.337 回答