STL(我认为是 3.3 版)的代码是这样的:
template <class _RandomAccessIter, class _Tp>
void __nth_element(_RandomAccessIter __first, _RandomAccessIter __nth,
_RandomAccessIter __last, _Tp*) {
while (__last - __first > 3) {
_RandomAccessIter __cut =
__unguarded_partition(__first, __last,
_Tp(__median(*__first,
*(__first + (__last - __first)/2),
*(__last - 1))));
if (__cut <= __nth)
__first = __cut;
else
__last = __cut;
}
__insertion_sort(__first, __last);
}
让我们稍微简化一下:
template <class Iter, class T>
void nth_element(Iter first, Iter nth, Iter last) {
while (last - first > 3) {
Iter cut =
unguarded_partition(first, last,
T(median(*first,
*(first + (last - first)/2),
*(last - 1))));
if (cut <= nth)
first = cut;
else
last = cut;
}
insertion_sort(first, last);
}
我在这里所做的是删除双下划线和 _Uppercase 的东西,这只是为了保护代码不受用户可以合法定义为宏的东西的影响。我还删除了最后一个参数,它只应该有助于模板类型推导,并为简洁起见重命名了迭代器类型。
正如您现在应该看到的,它重复划分范围,直到剩余范围中剩余的元素少于四个,然后对其进行简单排序。
现在,为什么是 O(n)?首先,最多三个元素的最终排序是 O(1),因为最多三个元素。现在,剩下的就是重复的分区。分区本身就是 O(n)。但是,在这里,每一步都会将下一步需要触及的元素数量减半,所以你有 O(n) + O(n/2) + O(n/4) + O(n/8) 这是如果你总结一下,小于 O(2n)。由于 O(2n) = O(n),因此您平均具有线性复杂度。