18

我在任何地方都没有找到这个特定的主题......

我在 23 个整数的 std::vector 中调用 nth_element() 算法约 400,000 次,更精确的“无符号短”值。

我想提高计算速度,而这个特定的调用需要很大一部分 CPU 时间。现在我注意到,与 std::sort() 一样,即使在最高优化级别和 NDEBUG 模式(Linux Clang 编译器)下,nth_element 函数在分析器中也是可见的,因此比较是内联的,而不是函数调用本身。好吧,更确切地说:不是 nth_element() 而是 std::__introselect() 是可见的。

由于数据的大小很小,我尝试使用二次排序函数 PIKSORT,当数据大小小于 20 个元素时,它通常比调用 std::sort 更快,可能是因为该函数将是内联的。

template <class CONTAINER>
inline void piksort(CONTAINER& arr)  // indeed this is "insertion sort"
{
    typename CONTAINER::value_type a;

    const int n = (int)arr.size();
    for (int j = 1; j<n; ++j) {
        a = arr[j];
        int i = j;
        while (i > 0 && a < arr[i - 1]) {
            arr[i] = arr[i - 1];
            i--;
        }
        arr[i] = a;
    }
}

但是,在这种情况下,这比使用 nth_element 慢。

另外,使用统计方法是不合适的,比 std::nth_element 更快的东西

最后,由于值在 0 到大约 20000 的范围内,因此直方图方法看起来不合适。

我的问题:有人知道一个简单的解决方案吗?我想我可能不是唯一一个必须经常调用 std::sort 或 nth_element 的人。

4

3 回答 3

16

You mentioned that the size of the array was always known to be 23. Moreover, the type used is unsigned short. In this case, you might try to use a sorting network of size 23; since your type is unsigned short, sorting the whole array with a sorting network might be even faster than partially sorting it with std::nth_element. Here is a very straightforward C++14 implementation of a sorting network of size 23 with 118 compare-exchange units, as described by Using Symmetry and Evolutionary Search to Minimize Sorting Networks:

template<typename RandomIt, typename Compare = std::less<>>
void network_sort23(RandomIt first, Compare compare={})
{
    swap_if(first[1u], first[20u], compare);
    swap_if(first[2u], first[21u], compare);
    swap_if(first[5u], first[13u], compare);
    swap_if(first[9u], first[17u], compare);
    swap_if(first[0u], first[7u], compare);
    swap_if(first[15u], first[22u], compare);
    swap_if(first[4u], first[11u], compare);
    swap_if(first[6u], first[12u], compare);
    swap_if(first[10u], first[16u], compare);
    swap_if(first[8u], first[18u], compare);
    swap_if(first[14u], first[19u], compare);
    swap_if(first[3u], first[8u], compare);
    swap_if(first[4u], first[14u], compare);
    swap_if(first[11u], first[18u], compare);
    swap_if(first[2u], first[6u], compare);
    swap_if(first[16u], first[20u], compare);
    swap_if(first[0u], first[9u], compare);
    swap_if(first[13u], first[22u], compare);
    swap_if(first[5u], first[15u], compare);
    swap_if(first[7u], first[17u], compare);
    swap_if(first[1u], first[10u], compare);
    swap_if(first[12u], first[21u], compare);
    swap_if(first[8u], first[19u], compare);
    swap_if(first[17u], first[22u], compare);
    swap_if(first[0u], first[5u], compare);
    swap_if(first[20u], first[21u], compare);
    swap_if(first[1u], first[2u], compare);
    swap_if(first[18u], first[19u], compare);
    swap_if(first[3u], first[4u], compare);
    swap_if(first[21u], first[22u], compare);
    swap_if(first[0u], first[1u], compare);
    swap_if(first[19u], first[22u], compare);
    swap_if(first[0u], first[3u], compare);
    swap_if(first[12u], first[13u], compare);
    swap_if(first[9u], first[10u], compare);
    swap_if(first[6u], first[15u], compare);
    swap_if(first[7u], first[16u], compare);
    swap_if(first[8u], first[11u], compare);
    swap_if(first[11u], first[14u], compare);
    swap_if(first[4u], first[11u], compare);
    swap_if(first[6u], first[8u], compare);
    swap_if(first[14u], first[16u], compare);
    swap_if(first[17u], first[20u], compare);
    swap_if(first[2u], first[5u], compare);
    swap_if(first[9u], first[12u], compare);
    swap_if(first[10u], first[13u], compare);
    swap_if(first[15u], first[18u], compare);
    swap_if(first[10u], first[11u], compare);
    swap_if(first[4u], first[7u], compare);
    swap_if(first[20u], first[21u], compare);
    swap_if(first[1u], first[2u], compare);
    swap_if(first[7u], first[15u], compare);
    swap_if(first[3u], first[9u], compare);
    swap_if(first[13u], first[19u], compare);
    swap_if(first[16u], first[18u], compare);
    swap_if(first[8u], first[14u], compare);
    swap_if(first[4u], first[6u], compare);
    swap_if(first[18u], first[21u], compare);
    swap_if(first[1u], first[4u], compare);
    swap_if(first[19u], first[21u], compare);
    swap_if(first[1u], first[3u], compare);
    swap_if(first[9u], first[10u], compare);
    swap_if(first[11u], first[13u], compare);
    swap_if(first[2u], first[6u], compare);
    swap_if(first[16u], first[20u], compare);
    swap_if(first[4u], first[9u], compare);
    swap_if(first[13u], first[18u], compare);
    swap_if(first[19u], first[20u], compare);
    swap_if(first[2u], first[3u], compare);
    swap_if(first[18u], first[20u], compare);
    swap_if(first[2u], first[4u], compare);
    swap_if(first[5u], first[17u], compare);
    swap_if(first[12u], first[14u], compare);
    swap_if(first[8u], first[12u], compare);
    swap_if(first[5u], first[7u], compare);
    swap_if(first[15u], first[17u], compare);
    swap_if(first[5u], first[8u], compare);
    swap_if(first[14u], first[17u], compare);
    swap_if(first[3u], first[5u], compare);
    swap_if(first[17u], first[19u], compare);
    swap_if(first[3u], first[4u], compare);
    swap_if(first[18u], first[19u], compare);
    swap_if(first[6u], first[10u], compare);
    swap_if(first[11u], first[16u], compare);
    swap_if(first[13u], first[16u], compare);
    swap_if(first[6u], first[9u], compare);
    swap_if(first[16u], first[17u], compare);
    swap_if(first[5u], first[6u], compare);
    swap_if(first[4u], first[5u], compare);
    swap_if(first[7u], first[9u], compare);
    swap_if(first[17u], first[18u], compare);
    swap_if(first[12u], first[15u], compare);
    swap_if(first[14u], first[15u], compare);
    swap_if(first[8u], first[12u], compare);
    swap_if(first[7u], first[8u], compare);
    swap_if(first[13u], first[15u], compare);
    swap_if(first[15u], first[17u], compare);
    swap_if(first[5u], first[7u], compare);
    swap_if(first[9u], first[10u], compare);
    swap_if(first[10u], first[14u], compare);
    swap_if(first[6u], first[11u], compare);
    swap_if(first[14u], first[16u], compare);
    swap_if(first[15u], first[16u], compare);
    swap_if(first[6u], first[7u], compare);
    swap_if(first[10u], first[11u], compare);
    swap_if(first[9u], first[12u], compare);
    swap_if(first[11u], first[13u], compare);
    swap_if(first[13u], first[14u], compare);
    swap_if(first[8u], first[9u], compare);
    swap_if(first[7u], first[8u], compare);
    swap_if(first[14u], first[15u], compare);
    swap_if(first[9u], first[10u], compare);
    swap_if(first[8u], first[9u], compare);
    swap_if(first[12u], first[14u], compare);
    swap_if(first[11u], first[12u], compare);
    swap_if(first[12u], first[13u], compare);
    swap_if(first[10u], first[11u], compare);
    swap_if(first[11u], first[12u], compare);
}

The swap_if utility function compares two parameters x and y with the predicate compare and swaps them if compare(y, x). My example uses a a generic swap_if function, but you can used an optimized version if you known that you will be comparing unsigned short values with operator< anyway (you might not need such a function if your compiler recognizes and optimizes the compare-exchange, but unfortunately, not all compilers do that - I am using g++5.2 with -O3 and I still need the following function for performance):

void swap_if(unsigned short& x, unsigned short& y)
{
    unsigned short dx = x;
    unsigned short dy = y;
    unsigned short tmp = x = std::min(dx, dy);
    y ^= dx ^ tmp;
}

Now, just to make sure that it is indeed faster, I decided to time std::nth_element when required to partial sort only the first 10 elements vs. sorting the whole 23 elements with the sorting network (1000000 times with different shuffled arrays). Here is what I get:

std::nth_element    1158ms
network_sort23      487ms

That said, my computer has been running for a bit of time and is a bit slow, but the difference in performance is neat. I believe that this difference will remain the same when I restart my computer. I may try it later and let you know.

Regarding how these times were generated, I used a modified version of this benchmark from my cpp-sort library. The original sorting network and swap_if functions come from there as well, so you can be sure that they have been tested more than once :)

EDIT: here are the results now that I have restarted my computer. The network_sort23 version is still two times faster than std::nth_element:

std::nth_element    369ms
network_sort23      154ms

EDIT²: if all you need in the median, you can trivially delete the compare-exchange units that are not needed to compute the final value that will be at the 11th position. The resulting median-finding network of size 23 that follows uses a different size-23 sorting network than the previous one, and it yields slightly better results:

swap_if(first[0u], first[1u], compare);
swap_if(first[2u], first[3u], compare);
swap_if(first[4u], first[5u], compare);
swap_if(first[6u], first[7u], compare);
swap_if(first[8u], first[9u], compare);
swap_if(first[10u], first[11u], compare);
swap_if(first[1u], first[3u], compare);
swap_if(first[5u], first[7u], compare);
swap_if(first[9u], first[11u], compare);
swap_if(first[0u], first[2u], compare);
swap_if(first[4u], first[6u], compare);
swap_if(first[8u], first[10u], compare);
swap_if(first[1u], first[2u], compare);
swap_if(first[5u], first[6u], compare);
swap_if(first[9u], first[10u], compare);
swap_if(first[1u], first[5u], compare);
swap_if(first[6u], first[10u], compare);
swap_if(first[5u], first[9u], compare);
swap_if(first[2u], first[6u], compare);
swap_if(first[1u], first[5u], compare);
swap_if(first[6u], first[10u], compare);
swap_if(first[0u], first[4u], compare);
swap_if(first[7u], first[11u], compare);
swap_if(first[3u], first[7u], compare);
swap_if(first[4u], first[8u], compare);
swap_if(first[0u], first[4u], compare);
swap_if(first[7u], first[11u], compare);
swap_if(first[1u], first[4u], compare);
swap_if(first[7u], first[10u], compare);
swap_if(first[3u], first[8u], compare);
swap_if(first[2u], first[3u], compare);
swap_if(first[8u], first[9u], compare);
swap_if(first[2u], first[4u], compare);
swap_if(first[7u], first[9u], compare);
swap_if(first[3u], first[5u], compare);
swap_if(first[6u], first[8u], compare);
swap_if(first[3u], first[4u], compare);
swap_if(first[5u], first[6u], compare);
swap_if(first[7u], first[8u], compare);
swap_if(first[12u], first[13u], compare);
swap_if(first[14u], first[15u], compare);
swap_if(first[16u], first[17u], compare);
swap_if(first[18u], first[19u], compare);
swap_if(first[20u], first[21u], compare);
swap_if(first[13u], first[15u], compare);
swap_if(first[17u], first[19u], compare);
swap_if(first[12u], first[14u], compare);
swap_if(first[16u], first[18u], compare);
swap_if(first[20u], first[22u], compare);
swap_if(first[13u], first[14u], compare);
swap_if(first[17u], first[18u], compare);
swap_if(first[21u], first[22u], compare);
swap_if(first[13u], first[17u], compare);
swap_if(first[18u], first[22u], compare);
swap_if(first[17u], first[21u], compare);
swap_if(first[14u], first[18u], compare);
swap_if(first[13u], first[17u], compare);
swap_if(first[18u], first[22u], compare);
swap_if(first[12u], first[16u], compare);
swap_if(first[15u], first[19u], compare);
swap_if(first[16u], first[20u], compare);
swap_if(first[12u], first[16u], compare);
swap_if(first[13u], first[16u], compare);
swap_if(first[19u], first[22u], compare);
swap_if(first[15u], first[20u], compare);
swap_if(first[14u], first[15u], compare);
swap_if(first[20u], first[21u], compare);
swap_if(first[14u], first[16u], compare);
swap_if(first[19u], first[21u], compare);
swap_if(first[15u], first[17u], compare);
swap_if(first[18u], first[20u], compare);
swap_if(first[15u], first[16u], compare);
swap_if(first[17u], first[18u], compare);
swap_if(first[19u], first[20u], compare);
swap_if(first[0u], first[12u], compare);
swap_if(first[2u], first[14u], compare);
swap_if(first[4u], first[16u], compare);
swap_if(first[6u], first[18u], compare);
swap_if(first[8u], first[20u], compare);
swap_if(first[10u], first[22u], compare);
swap_if(first[2u], first[12u], compare);
swap_if(first[10u], first[20u], compare);
swap_if(first[4u], first[12u], compare);
swap_if(first[6u], first[14u], compare);
swap_if(first[8u], first[16u], compare);
swap_if(first[10u], first[18u], compare);
swap_if(first[8u], first[12u], compare);
swap_if(first[10u], first[14u], compare);
swap_if(first[10u], first[12u], compare);
swap_if(first[1u], first[13u], compare);
swap_if(first[3u], first[15u], compare);
swap_if(first[5u], first[17u], compare);
swap_if(first[7u], first[19u], compare);
swap_if(first[9u], first[21u], compare);
swap_if(first[3u], first[13u], compare);
swap_if(first[11u], first[21u], compare);
swap_if(first[5u], first[13u], compare);
swap_if(first[7u], first[15u], compare);
swap_if(first[9u], first[17u], compare);
swap_if(first[11u], first[19u], compare);
swap_if(first[9u], first[13u], compare);
swap_if(first[11u], first[15u], compare);
swap_if(first[11u], first[13u], compare);
swap_if(first[11u], first[12u], compare);

There are probably smarter ways to generate median-finding networks, but I don't think that extensive research has been done on the subject. Therefore, it's probably the best method you can use as of now. The result isn't awesome but it still uses 104 compare-exchange units instead of 118.

于 2015-10-24T14:38:17.050 回答
4

大概的概念

查看 MSVC2013 中的源代码std::nth_element,似乎N <= 32的情况是通过插入排序解决的。这意味着 STL 实现者意识到,尽管对于该大小有更好的渐近性,但进行随机分区会更慢。

提高性能的方法之一是优化排序算法。@Morwenn 的回答展示了如何使用排序网络对 23 个元素进行排序,这被认为是对小型恒定大小数组进行排序的最快方法之一。我将研究另一种方法,即在没有排序算法的情况下计算中位数。事实上,我根本不会置换输入数组。

由于我们谈论的是小数组,我们需要以最简单的方式实现一些O(N^2)算法。理想情况下,它应该根本没有分支,或者只有可预测的分支。此外,算法的简单结构可以让我们对其进行矢量化,进一步提高其性能。

算法

我决定遵循计数方法,这里使用它来加速小型线性搜索。首先,假设所有元素都是不同的。选择数组的任何元素:元素的数量小于它在排序数组中定义的位置。我们可以遍历所有元素,并为每个元素计算小于它的元素数。如果排序后的索引具有所需的值,我们可以停止算法。

不幸的是,在一般情况下可能有相同的元素。我们必须使我们的算法明显更慢和更复杂来处理它们。我们可以计算它的可能排序索引的间隔,而不是计算元素的唯一排序索引。对于任何元素,计算小于它的元素数(L)和等于它的元素数(E)就足够了,然后排序索引适合范围[L, L+R)。如果此区间包含所需的排序索引(即N/2),那么我们可以停止算法并返回考虑的元素。

for (size_t i = 0; i < n; i++) {
    auto x = arr[i];
    //count number of "less" and "equal" elements
    int cntLess = 0, cntEq = 0;
    for (size_t j = 0; j < n; j++) {
        cntLess += arr[j] < x;
        cntEq += arr[j] == x;
    }
    //fast range checking from here: https://stackoverflow.com/a/17095534/556899
    if ((unsigned int)(idx - cntLess) < cntEq)
        return x;
}

矢量化

构造的算法只有一个分支,这是相当可预测的:它在所有情况下都失败,除了我们停止算法的唯一情况。该算法很容易使用每个 SSE 寄存器的 8 个元素进行矢量化。因为我们必须在最后一个元素之后访问一些元素,所以我假设输入数组用max=2^15-1值填充,最多 24 或 32 个元素。

第一种方法是通过 向量化内循环j。在这种情况下,内部循环将只执行 3 次,但在完成后必须执行两次 8 宽的缩减。他们比内循环本身吃更多的时间。结果,这样的矢量化不是很有效。

第二种方法是通过 向量化外循环i。在这种情况下,我们一次处理 8 个元素x = arr[i]。对于每个包,我们将其与arr[j]内部循环中的每个元素进行比较。在内循环之后,我们对整个 8 个元素包执行矢量化范围检查。如果其中任何一个成功,我们使用简单的标量代码确定确切的数字(无论如何它消耗的时间很少)。

__m128i idxV = _mm_set1_epi16(idx);
for (size_t i = 0; i < n; i += 8) {
    //load pack of 8 elements
    auto xx = _mm_loadu_si128((__m128i*)&arr[i]);
    //count number of less/equal elements for each element in the pack
    __m128i cntLess = _mm_setzero_si128();
    __m128i cntEq = _mm_setzero_si128();
    for (size_t j = 0; j < n; j++) {
        __m128i vAll = _mm_set1_epi16(arr[j]);
        cntLess = _mm_sub_epi16(cntLess, _mm_cmplt_epi16(vAll, xx));
        cntEq = _mm_sub_epi16(cntEq, _mm_cmpeq_epi16(vAll, xx));
    }
    //perform range check for 8 elements at once
    __m128i mask = _mm_andnot_si128(_mm_cmplt_epi16(idxV, cntLess), _mm_cmplt_epi16(idxV, _mm_add_epi16(cntLess, cntEq)));
    if (int bm = _mm_movemask_epi8(mask)) {
        //range check succeeds for one of the elements, find and return it 
        for (int t = 0; t < 8; t++)
            if (bm & (1 << (2*t)))
                return arr[i + t];
    }
}

在这里,我们_mm_set1_epi16在最里面的循环中看到了内在的。GCC 似乎有一些性能问题。无论如何,每次最内层迭代都在消耗时间,如果我们在最内层循环中一次处理 8 个元素也可以减少时间。在这种情况下,我们可以执行 1 个矢量化加载和 14 个解包指令来获取vAll8 个元素。此外,我们必须为循环体中的八个元素编写比较和计数代码,因此它也可以作为 8x 展开。生成的代码是最快的,可以在下面找到它的链接。

比较

我在 Ivy Bridge 3.4 Ghz 处理器上对各种解决方案进行了基准测试。您可以在下面看到2^23 ~= 8M调用的总计算时间(以秒为单位)(第一个数字)。第二个数字是结果的校验和。

MSVC 2013 x64 ( /O2 ) 上的结果:

memcpy only: 0.020
std::nth_element: 2.110 (1186136064)
network sort: 0.630 (1186136064)              //solution by @Morwenn (I had to change swap_if)
trivial count: 2.266 (1186136064)             //scalar algorithm (presented above)
vectorized count: 0.692 (1186136064)          //vectorization by j
vectorized count (T): 0.602 (1186136064)      //vectorization by i (presented above)
vectorized count (both): 0.450 (1186136064)   //vectorization by i and j

MinGW GCC 4.8.3 x64 ( -O3 -msse4 ) 的结果:

memcpy only: 0.016
std::nth_element: 1.981 (1095237632)
network sort: 0.531 (1095237632)              //original swap_if used
trivial count: 1.482 (1095237632)
vectorized count: 0.655 (1095237632)
vectorized count (T): 2.668 (1095237632)      //GCC generates some crap
vectorized count (both): 0.374 (1095237632)

如您所见,针对 23 个 16 位元素提出的矢量化算法比基于排序的方法要快一点(顺便说一句,在较旧的 CPU 上,我只看到 5% 的时间差异)。如果你能保证所有元素都是不同的,你就可以简化算法,让它更快。

所有算法的完整代码都可以在这里找到,包括所有的测试代码。

于 2015-10-25T03:27:31.247 回答
2

我发现这个问题很有趣,所以我尝试了所有我能想到的算法。
结果如下:

testing 100000 repetitions
variant 0, no-op (for overhead measure)
5 ms
variant 1, vector + nth_element
205 ms
variant 2, multiset + advance
745 ms
variant 2b, set (not fully conformant)
787 ms
variant 3, list + lower_bound
589 ms
variant 3b, list + block-allocator
269 ms
variant 4, avl-tree + insert_sorted
645 ms
variant 4b, avl-tree + prune
682 ms
variant 5, histogram
1429 ms

我想我们可以得出结论,您已经在使用最快的算法。男孩是我错了。但是,如果您可以接受一个近似答案,则可能有更快的方法,例如median of medians
如果你有兴趣,来源在这里

于 2015-10-24T13:27:06.310 回答