32

StackOverflow 和其他地方有很多声称nth_elementO(n)并且通常使用 Introselect 实现:http ://en.cppreference.com/w/cpp/algorithm/nth_element

我想知道如何实现这一点。我查看了维基百科对 Introselect 的解释,这让我更加困惑。算法如何在 QSort 和 Median-of-Medians 之间切换?

我在这里找到了 Introsort 论文:http ://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.14.5196&rep=rep1&type=pdf但上面写着:

在本文中,我们专注于排序问题,并在后面的部分中仅简要地回到选择问题。

我试图通读 STL 本身以了解它nth_element是如何实现的,但这确实很快。

有人可以向我展示如何实现 Introselect 的伪代码吗?甚至更好,当然是 STL 以外的实际 C++ 代码 :)

4

3 回答 3

22

免责声明:我不知道std::nth_element在任何标准库中是如何实现的。

如果您知道快速排序的工作原理,您可以轻松地对其进行修改以执行该算法所需的操作。快速排序的基本思想是,在每一步中,将数组分成两部分,使得所有小于枢轴的元素都在左子数组中,所有等于或大于枢轴的元素都在右子数组中. (快速排序的修改称为三元快速排序创建第三个子数组,其中所有元素都等于枢轴。然后右子数组只包含严格大于枢轴的条目。)然后快速排序通过递归排序左子和右子继续-阵列。

如果您只想将第n个元素移动到位,而不是递归到两个子数组中,您可以在每一步中告诉您是否需要下降到左子数组或右子数组。(您知道这一点,因为已排序数组中的第n个元素具有索引n,因此它变成了比较索引的问题。)所以——除非你的快速排序遭受最坏情况的退化——你在每一步中将剩余数组的大小大致减半. (您永远不会再查看另一个子数组。)因此,平均而言,您在每个步骤中处理以下长度的数组:

  1. Θ( N )
  2. Θ( N / 2)
  3. Θ( N / 4)
  4. …</li>

每个步骤在它所处理的数组的长度上都是线性的。(您遍历它一次并根据它与枢轴的比较来决定每个元素应该进入哪个子数组。)

您可以看到,经过 Θ(log( N )) 步骤后,我们最终将到达一个单例数组并完成。如果你总结N (1 + 1/2 + 1/4 + ...),你会得到 2 N。或者,在平均情况下,因为我们不能希望枢轴总是正好是中位数,大约是 Θ( N )。

于 2015-03-19T13:42:17.167 回答
16

你问了两个问题,名义上的一个

nth_element 是如何实现的?

你已经回答了:

StackOverflow 和其他地方有很多声称 nth_element 是 O(n) 并且它通常使用 Introselect 实现。

我也可以通过查看我的 stdlib 实现来确认这一点。(稍后会详细介绍。)

还有一个你不明白答案的地方:

算法如何在 QSort 和 Median-of-Medians 之间切换?

让我们看看我从 stdlib 中提取的伪代码:

nth_element(first, nth, last)
{ 
  if (first == last || nth == last)
    return;

  introselect(first, nth, last, log2(last - first) * 2);
}

introselect(first, nth, last, depth_limit)
{
  while (last - first > 3)
  {
      if (depth_limit == 0)
      {
          // [NOTE by editor] This should be median-of-medians instead.
          // [NOTE by editor] See Azmisov's comment below
          heap_select(first, nth + 1, last);
          // Place the nth largest element in its final position.
          iter_swap(first, nth);
          return;
      }
      --depth_limit;
      cut = unguarded_partition_pivot(first, last);
      if (cut <= nth)
        first = cut;
      else
        last = cut;
  }
  insertion_sort(first, last);
}

在不深入了解引用函数的细节的情况下heap_selectunguarded_partition_pivot我们可以清楚地看到,这nth_element给出了 introselect2 * log2(size)细分步骤(在最好的情况下是 quickselect 所需的两倍),直到heap_select启动并永久解决问题。

于 2015-03-19T14:05:03.920 回答
10

STL(我认为是 3.3 版)的代码是这样的:

template <class _RandomAccessIter, class _Tp>
void __nth_element(_RandomAccessIter __first, _RandomAccessIter __nth,
                   _RandomAccessIter __last, _Tp*) {
  while (__last - __first > 3) {
    _RandomAccessIter __cut =
      __unguarded_partition(__first, __last,
                            _Tp(__median(*__first,
                                         *(__first + (__last - __first)/2),
                                         *(__last - 1))));
    if (__cut <= __nth)
      __first = __cut;
    else 
      __last = __cut;
  }
  __insertion_sort(__first, __last);
}

让我们稍微简化一下:

template <class Iter, class T>
void nth_element(Iter first, Iter nth, Iter last) {
  while (last - first > 3) {
    Iter cut =
      unguarded_partition(first, last,
                          T(median(*first,
                                   *(first + (last - first)/2),
                                   *(last - 1))));
    if (cut <= nth)
      first = cut;
    else 
      last = cut;
  }
  insertion_sort(first, last);
}

我在这里所做的是删除双下划线和 _Uppercase 的东西,这只是为了保护代码不受用户可以合法定义为宏的东西的影响。我还删除了最后一个参数,它只应该有助于模板类型推导,并为简洁起见重命名了迭代器类型。

正如您现在应该看到的,它重复划分范围,直到剩余范围中剩余的元素少于四个,然后对其进行简单排序。

现在,为什么是 O(n)?首先,最多三个元素的最终排序是 O(1),因为最多三个元素。现在,剩下的就是重复的分区。分区本身就是 O(n)。但是,在这里,每一步都会将下一步需要触及的元素数量减半,所以你有 O(n) + O(n/2) + O(n/4) + O(n/8) 这是如果你总结一下,小于 O(2n)。由于 O(2n) = O(n),因此您平均具有线性复杂度。

于 2015-03-19T13:42:55.933 回答