-1

我尝试将Hoare 分区方案作为 Quickselect 算法的一部分来实现,但它似乎每次都能给我各种答案。

这是在findKthBest给定数组 ( data) 和其中的元素数 ( low = 0high = 4如果有 5 个元素的情况下 ) 中找到数组中第 K 个最大数的函数:

def findKthBest(k, data, low, high):
    # choose random pivot
    pivotindex = random.randint(low, high)

    # move the pivot to the end
    data[pivotindex], data[high] = data[high], data[pivotindex]

    # partition
    pivotmid = partition(data, low, high, data[high])

    # move the pivot back
    data[pivotmid], data[high] = data[high], data[pivotmid]

    # continue with the relevant part of the list
    if pivotmid == k:
        return data[pivotmid]
    elif k < pivotmid:
        return findKthBest(k, data, low, pivotmid - 1)
    else:
        return findKthBest(k, data, pivotmid + 1, high)

该函数partition()有四个变量:

  • data(一个列表,例如 5 个元素),
  • l(列表中相关部分的起始位置,例如0)
  • r(列表中相关部分的结束位置,也是放置枢轴的位置,例如 4)
  • pivot(枢轴的值)
def partition(data, l, r, pivot):
    while True:
        while data[l] < pivot:
            #statistik.nrComparisons += 1
            l = l + 1
        r = r - 1    # skip the pivot
        while r != 0 and data[r] > pivot:
            #statistik.nrComparisons += 1
            r = r - 1
        if r > l:
            data[r], data[l] = data[l], data[r]
        return r

现在我每次都简单地得到各种结果,似乎递归效果不太好(有时它以达到最大递归错误结束),而不是每次都给出一个恒定的结果。我究竟做错了什么?

4

1 回答 1

0

首先,函数中似乎有错误 partition()

如果你仔细比较你的代码和 wiki 中的代码,你会发现不同之处。函数应该是:

def partition(data, l, r, pivot):
    while True:
        while data[l] < pivot:
            #statistik.nrComparisons += 1
            l = l + 1
        r = r - 1    # skip the pivot
        while r != 0 and data[r] > pivot:
            #statistik.nrComparisons += 1
            r = r - 1
        if r >= l:
            return r

        data[r], data[l] = data[l], data[r]

二、例如:

  • data = [1, 0, 2, 4, 3]你得到一个带有pivotmid=3后分区的数组
  • 您想找到第 4 个最大值 ( k=4),即 1

data下一个解析到的数组findKthBest()将变为[1, 0].
因此,接下来findKthBest()应该找到数组的最大值[1, 0]

def findKthBest(k, data, low, high):
    ......

    # continue with the relevant part of the list
    if pivotmid == k:
        return data[pivotmid]
    elif k < pivotmid:
        #Corrected
        return findKthBest(k-pivotmid, data, low, pivotmid - 1)
    else:
        return findKthBest(k, data, pivotmid + 1, high)
于 2019-04-11T10:35:13.090 回答