1

我尝试使用我的代码找到第 k 个最小元素,但无法修复我的代码中的错误。当它尝试使用 pivot = 0 对 [0, 0, 2] 进行分区时,它正在循环。

import java.util.Arrays;

public class OrderStat {

    public static void main(String[] args) {
        int[] uA = {13, 32, 28, 17, 2, 0, 14, 34, 35, 0};
        System.out.println("Initial array: " + Arrays.toString(uA));
        int kth = 3; // We will try to find 3rd smallest element(or 2nd if we will count from 0).
        int result = getKthSmallestElement(uA, 0, uA.length - 1, kth - 1);
        System.out.println(String.format("The %d smallest element is %d", kth, result));

        System.out.println("-------------------------------------");
        Arrays.sort(uA);
        System.out.println("Sorted array for check: " + Arrays.toString(uA));
    }

    private static int getKthSmallestElement(int[] uA, int start, int end, int kth) {
        int l = start;
        int r = end;
        int pivot = uA[start];
        System.out.println("===================");
        System.out.println(String.format("start=%d end=%d", start, end));
        System.out.println("pivot = " + pivot);

        //ERROR HERE: When we will work with [0, 0, 2] part of array with pivot = 0 it will give us infinite loop;
        while (l < r) {
            while (uA[l] < pivot) {
                l++;
            }
            while (uA[r] > pivot) {
                r--;
            }
            if (l < r) {
                int tmp = uA[l];
                uA[l] = uA[r];
                uA[r] = tmp;
            }
        }
        System.out.println("After partitioning: " + Arrays.toString(uA) + "\n");


        if (l < kth)
            return getKthSmallestElement(uA, l + 1, end, kth);
        else if (l > kth)
            return getKthSmallestElement(uA, start, l - 1, kth);

        return uA[l];
    }

}

请解释一下,如何解决这个问题。

4

1 回答 1

3

交换后

if (l < r) {
    int tmp = uA[l];
    uA[l] = uA[r];
    uA[r] = tmp;
}

您需要移动lr(或至少其中一个,以取得任何进展)到下一个位置 ( ++l; --r;)。否则,如果两个值都等于枢轴,则无限循环。

也可用于快速排序的正确分区是

// make sure to call it only with valid indices, 0 <= start <= end < array.length
private int partition(int[] array, int start, int end) {
    // trivial case, single element array - garbage if end < start
    if(end <= start) return start;
    int pivot = array[start]; // not a good choice of pivot in general, but meh
    int left = start+1, right = end;
    while(left < right) {
        // move left index to first entry larger than pivot or right
        while(left < right && array[left] <= pivot) ++left;
        // move right index to last entry not larger than pivot or left
        while(right > left && array[right] > pivot) --right;
        // Now, either
        // left == right, or
        // left < right and array[right] <= pivot < array[left]
        if (left < right) {
            int tmp = array[left];
            array[left] = array[right];
            array[right] = tmp;
            // move on
            ++left;
            --right;
        }
    }
    // Now left >= right.
    // If left == right, we don't know whether array[left] is larger than the pivot or not,
    // but array[left-1] certainly is not larger than the pivot.
    // If left > right, we just swapped and incremented before exiting the loop,
    // so then left == right+1 and array[right] <= pivot < array[left].
    if (left > right || array[left] > pivot) {
        --left;
    }
    // Now array[i] <= pivot for start <= i <= left, and array[j] > pivot for left < j <= end
    // swap pivot in its proper place in the sorted array
    array[start] = array[left];
    array[left] = pivot;
    // return pivot position
    return left;
}

然后你可以在一个数组中找到第k个最小的元素

int findKthSmallest(int array, int k) {
    if (k < 1) throw new IllegalArgumentException("k must be positive");
    if (array.length < k) throw new IllegalArgumentException("Array too short");
    int left = 0, right = array.length-1, p;
    --k; // 0-based indices
    while(true) {
        p = partition(array, left, right);
        if (p == k) return array[p];
        if (p < k) {
            left = p+1;
            k -= left;
        } else {
            right = p-1;
        }
    }
    // dummy return, never reached
    return 0;
}
于 2012-11-02T21:31:38.453 回答