14

您知道 Java 中（二叉）线段树的良好实现吗？

4

4 个回答

13
public class SegmentTree {
    /**
     * A node of the sum segment tree. It covers the inclusive index range
     * [leftIndex, rightIndex] of the original array.
     */
    public static class STNode {
        /** First index (inclusive) of the original array covered by this node. */
        int leftIndex;
        /** Last index (inclusive) of the original array covered by this node. */
        int rightIndex;
        /** Sum of the original values in [leftIndex, rightIndex]. */
        int sum;
        /** Child covering the lower part of the range; null for leaves. */
        STNode leftNode;
        /** Child covering the upper part of the range; null for leaves. */
        STNode rightNode;
    }

    /**
     * Builds a sum segment tree over A[l..r] (both ends inclusive).
     *
     * @param A source array
     * @param l first index to include
     * @param r last index to include
     * @return root node covering [l, r]
     * @throws IllegalArgumentException if the range is empty or does not lie inside A
     */
    static STNode constructSegmentTree(int[] A, int l, int r) {
        // Validate once up front: an empty range (e.g. an empty array) used to
        // fail deep in the recursion with an ArrayIndexOutOfBoundsException.
        if (l < 0 || r >= A.length || l > r) {
            throw new IllegalArgumentException(
                    "invalid range [" + l + ", " + r + "] for array of length " + A.length);
        }
        return build(A, l, r);
    }

    /** Recursively builds the subtree for the already-validated range [l, r]. */
    private static STNode build(int[] A, int l, int r) {
        STNode node = new STNode();
        node.leftIndex = l;
        node.rightIndex = r;
        if (l == r) {
            node.sum = A[l];
            return node;
        }
        // Unsigned shift keeps the midpoint correct even if l + r exceeds Integer.MAX_VALUE.
        int mid = (l + r) >>> 1;
        node.leftNode = build(A, l, mid);
        node.rightNode = build(A, mid + 1, r);
        node.sum = node.leftNode.sum + node.rightNode.sum;
        return node;
    }

    /**
     * Returns the sum of the original values whose indices lie in [l, r] (inclusive).
     * Returns 0 when the query range does not intersect the tree, or when l > r.
     *
     * @param root root of the (sub)tree to query
     * @param l first index of the query range
     * @param r last index of the query range
     * @return sum over the intersection of [l, r] with the node's range
     */
    static int getSum(STNode root, int l, int r) {
        if (root.leftIndex >= l && root.rightIndex <= r) {
            // Node range lies entirely inside the query.
            return root.sum;
        }
        if (root.rightIndex < l || root.leftIndex > r) {
            // Node range is disjoint from the query.
            return 0;
        }
        return getSum(root.leftNode, l, r) + getSum(root.rightNode, l, r);
    }

    /**
     * Sets the value at {@code index} of the original array to {@code newValue} and
     * updates every sum on the path from the leaf up to {@code root}.
     *
     * @param root root of the (sub)tree containing the index
     * @param index index of the number to be updated in the original array
     * @param newValue value to store at that index
     * @return difference between the new and old values
     * @throws IllegalArgumentException if {@code index} is outside the root's range
     */
    static int updateValueAtIndex(STNode root, int index, int newValue) {
        // Without this guard an out-of-range index reached a leaf, missed the leaf
        // test, and then dereferenced a null child.
        if (index < root.leftIndex || index > root.rightIndex) {
            throw new IllegalArgumentException(
                    "index " + index + " outside [" + root.leftIndex + ", " + root.rightIndex + "]");
        }
        if (root.leftIndex == root.rightIndex) {
            // Reached the leaf holding the index.
            int diff = newValue - root.sum;
            root.sum = newValue;
            return diff;
        }
        // Route by the child's actual range rather than recomputing the build midpoint.
        STNode child = index <= root.leftNode.rightIndex ? root.leftNode : root.rightNode;
        int diff = updateValueAtIndex(child, index, newValue);
        root.sum += diff;
        return diff;
    }
}
于 2014-01-21T07:57:47.143 回答
3

这已在开源项目 Layout Management SW Package 中实现。

这是子包的链接

您可能会发现这些代码很有用。我既没有验证过也没有运行过它，而且快速搜索代码和网站后也没能找到其所用的许可证，因此使用时请自担风险（caveat emptor）。

您也许可以联系作者,但最后一次活动似乎是 2008 年 8 月。

于 2009-04-28T12:49:19.120 回答
1

这里是:

import java.util.Scanner;

public class MinimumSegmentTree {

    /** Source of all program input (standard input). */
    static Scanner in = new Scanner(System.in);

    /**
     * Reads an array size {@code n}, then {@code n} integers, then a query count {@code q}
     * followed by {@code q} pairs {@code s e}. For each pair it prints the minimum of
     * a[s..e] (0-based, inclusive). A query that does not intersect the array prints
     * {@link Integer#MAX_VALUE}.
     *
     * @throws IllegalArgumentException if {@code n} is not positive
     */
    public static void main(String[] args) {
        final int n = in.nextInt();
        if (n <= 0) {
            throw new IllegalArgumentException("array size must be positive, got " + n);
        }
        int[] a = new int[n];
        for (int i = 0; i < n; i++) {
            a[i] = in.nextInt();
        }

        // The tree layout has at most 2 * nextPowerOfTwo(n) - 1 nodes. Integer
        // arithmetic avoids the rounding pitfalls of a floating-point log2.
        int[] segmentTree = new int[2 * nextPowerOfTwo(n) - 1];
        formSegmentTree(a, segmentTree, 0, n - 1, 0);

        final int q = in.nextInt();
        for (int i = 0; i < q; i++) {
            int s = in.nextInt();
            int e = in.nextInt();
            System.out.println(getMinimumOverRange(segmentTree, s, e, 0, n - 1, 0));
        }
    }

    /** Returns the smallest power of two that is >= n (valid for 1 <= n <= 2^30). */
    private static int nextPowerOfTwo(int n) {
        return n <= 1 ? 1 : Integer.highestOneBit(n - 1) << 1;
    }

    /**
     * Returns the minimum over the intersection of the query range [qs, qe] with the
     * range [s, e] covered by the node at {@code pos}.
     *
     * @return the minimum, or Integer.MAX_VALUE when the ranges do not intersect
     */
    private static int getMinimumOverRange(int[] segmentTree, int qs, int qe, int s, int e, int pos) {
        if (qs <= s && qe >= e) {
            return segmentTree[pos];
        }
        if (qs > e || s > qe) {
            // Identity element for min. The previous sentinel (10000000) was
            // returned whenever every in-range value exceeded it, giving wrong answers.
            return Integer.MAX_VALUE;
        }
        int mid = (s + e) >>> 1;
        return Math.min(
                getMinimumOverRange(segmentTree, qs, qe, s, mid, 2 * pos + 1),
                getMinimumOverRange(segmentTree, qs, qe, mid + 1, e, 2 * pos + 2));
    }

    /**
     * Fills {@code segmentTree[pos]} with the minimum of a[s..e] and recursively
     * fills its children at {@code 2*pos+1} (left half) and {@code 2*pos+2} (right half).
     */
    private static void formSegmentTree(int[] a, int[] segmentTree, int s, int e, int pos) {
        if (s == e) {
            segmentTree[pos] = a[s];
            return;
        }
        // Must match the split used in getMinimumOverRange.
        int mid = (s + e) >>> 1;
        formSegmentTree(a, segmentTree, s, mid, 2 * pos + 1);
        formSegmentTree(a, segmentTree, mid + 1, e, 2 * pos + 2);
        segmentTree[pos] = Math.min(segmentTree[2 * pos + 1], segmentTree[2 * pos + 2]);
    }

}
于 2015-12-30T15:08:29.210 回答
0

算法和单元测试

public class NumArrayTest {

    /** An empty input has nothing to add up, so any range query yields zero. */
    @Test
    public void testUpdateSumRange_WithEmpty() throws Exception {
        NumArray emptyTree = new NumArray(new int[0]);
        assertEquals(0, emptyTree.sumRange(0, 0));
    }

    /** A lone element is returned as-is and can be overwritten. */
    @Test
    public void testUpdateSumRange_WithSingleton() throws Exception {
        int[] values = {1};
        NumArray tree = new NumArray(values);
        assertEquals(1, tree.sumRange(0, 0));

        tree.update(0, 2);
        assertEquals(2, tree.sumRange(0, 0));
    }

    /** Even element count: 3 + 4 + 5 = 12, then replacing 4 with 2 gives 10. */
    @Test
    public void testUpdateSumRange_WithPairElements() throws Exception {
        int[] values = {1, 2, 3, 4, 5, 6};
        NumArray tree = new NumArray(values);
        assertEquals(12, tree.sumRange(2, 4));

        tree.update(3, 2);
        assertEquals(10, tree.sumRange(2, 4));
    }

    /** Odd element count: same range and update as the even-count case. */
    @Test
    public void testUpdateSumRange_WithInPairElements() throws Exception {
        int[] values = {1, 2, 3, 4, 5, 6, 7};
        NumArray tree = new NumArray(values);
        assertEquals(12, tree.sumRange(2, 4));

        tree.update(3, 2);
        assertEquals(10, tree.sumRange(2, 4));
    }
}



public class NumArray {

    /** Root of the sum tree; {@code null} when built from an empty array. */
    private final Node root;

    /** A tree node storing the sum of the inclusive index range [begin, end]. */
    private static class Node {
        private final int begin;
        private final int end;
        private final Node left;
        private final Node right;
        private int sum;

        Node(int begin, int end, int sum, Node left, Node right) {
            this.begin = begin;
            this.end = end;
            this.left = left;
            this.right = right;
            this.sum = sum;
        }

        /** True when this node covers exactly one index. */
        boolean isLeaf() {
            return end == begin;
        }

        /** True when index {@code idx} falls inside [begin, end]. */
        boolean covers(int idx) {
            return begin <= idx && idx <= end;
        }

        /** True when [begin, end] lies completely within the query [lo, hi]. */
        boolean within(int lo, int hi) {
            return lo <= begin && end <= hi;
        }

        /** True when [begin, end] shares no index with the query [lo, hi]. */
        boolean disjoint(int lo, int hi) {
            return end < lo || hi < begin;
        }
    }

    /**
     * Builds the sum tree over {@code nums}; an empty array yields an empty tree.
     */
    public NumArray(int[] nums) {
        root = nums.length == 0 ? null : build(nums, 0, nums.length - 1);
    }

    /** Recursively builds the subtree covering nums[begin..end]. */
    private Node build(int[] nums, int begin, int end) {
        if (begin != end) {
            // First index of the upper half; the lower half ends just before it.
            int split = (begin + end) / 2 + 1;
            Node lower = build(nums, begin, split - 1);
            Node upper = build(nums, split, end);
            return new Node(begin, end, lower.sum + upper.sum, lower, upper);
        }
        return new Node(begin, end, nums[begin], null, null);
    }

    /**
     * Replaces the value at index {@code i} with {@code val}. Does nothing on an
     * empty tree.
     *
     * @throws IllegalArgumentException if {@code i} is outside the array
     */
    public void update(int i, int val) {
        if (root == null) {
            return;
        }
        if (!root.covers(i)) {
            throw new IllegalArgumentException("i not in range");
        }
        assign(root, i, val);
    }

    /** Writes {@code val} into the leaf for {@code i} and refreshes the sums above it. */
    private void assign(Node node, int i, int val) {
        if (node.isLeaf()) {
            node.sum = val;
            return;
        }
        assign(node.left.covers(i) ? node.left : node.right, i, val);
        node.sum = node.left.sum + node.right.sum;
    }

    /**
     * Returns the sum of the values at indices in [i, j] (inclusive). Returns 0 for
     * an empty tree or a range that does not intersect the array.
     */
    public int sumRange(int i, int j) {
        return root == null ? 0 : sumRange(root, i, j);
    }

    /** Sum over the intersection of [i, j] with the range covered by {@code node}. */
    private int sumRange(Node node, int i, int j) {
        if (node.within(i, j)) {
            return node.sum;
        }
        if (node.disjoint(i, j)) {
            return 0;
        }
        return sumRange(node.left, i, j) + sumRange(node.right, i, j);
    }

}
于 2015-12-03T15:00:48.180 回答