6

我正在查看 GCC-4.8 为 x86_64 生成的代码,并想知道是否有更好(更快)的方法来计算三个值的最小值。

这是 Python 的集合模块的摘录,它计算mrightindex+1和的最小值leftindex

    ssize_t m = n;
    if (m > rightindex + 1)
        m = rightindex + 1;
    if (m > leftindex)
        m = leftindex;

GCC 使用 CMOV 生成串行依赖代码:

leaq    1(%rbp), %rdx
cmpq    %rsi, %rdx
cmovg   %rsi, %rdx
cmpq    %rbx, %rdx
cmovg   %rbx, %rdx

是否有更快的代码可以通过消除数据依赖性来利用处理器乱序并行执行?我想知道是否有已知技巧可以在不使用条件或谓词指令的情况下计算多个值的最小值。我还想知道在这种情况下是否有一些饱和算术内在函数会有所帮助。

编辑:

4

3 回答 3

9

以下是建议方法的基准测试结果。处理器是运行在 4GHz 的 Intel Core-i7 2600K。单位是每循环纳秒(越小越好)。除了被测的实际 min3 代码外,测量还包括伪随机测试数据生成和函数调用开销。

                          gcc cmov      conditional      classical
                          sequence      branches         branchless

pseudo-random data        2.24          6.31             2.39
fixed data                2.24          2.99             2.39

基准测试源代码

functions.s:评估的 3 个解决方案:

//----------------------------------------------------------------------------

.text
.intel_syntax noprefix

//-----------------------------------------------------------------------------
// uint64_t min3_a (uint64_t rcx, uint64_t rdx, uint64_t r8)
.globl min3_a
min3_a:
   mov   rax, rcx
   cmp   rax, rdx
   cmovg rax, rdx
   cmp   rax, r8
   cmovg rax, r8
   retq

//-----------------------------------------------------------------------------
// uint64_t min3_b (uint64_t rcx, uint64_t rdx, uint64_t r8)
.globl min3_b
min3_b:
   mov   rax, rcx
   cmp   rax, rdx
   jle   skip1
   mov   rax, rdx
skip1:
   cmp   rax, r8
   jle   skip2
   mov   rax, r8
skip2:
   retq

//-----------------------------------------------------------------------------
// uint64_t min3_c (uint64_t rcx, uint64_t rdx, uint64_t r8)
.globl min3_c
min3_c:
   sub   rdx, rcx
   sbb   rax, rax
   and   rax, rdx
   add   rcx, rax
   sub   r8,  rcx
   sbb   rax, rax
   and   rax, r8
   add   rax, rcx
   retq

//-----------------------------------------------------------------------------

min3test.c,主程序(mingw64/windows):

#define __USE_MINGW_ANSI_STDIO 1
#include <windows.h>
#include <stdio.h>
#include <stdint.h>

 uint64_t min3_a (uint64_t rcx, uint64_t rdx, uint64_t r8);
 uint64_t min3_b (uint64_t rcx, uint64_t rdx, uint64_t r8);
 uint64_t min3_c (uint64_t rcx, uint64_t rdx, uint64_t r8);

//-----------------------------------------------------------------------------
//
//  queryPerformanceCounter - similar to QueryPerformanceCounter, but returns
//                            count directly.

uint64_t queryPerformanceCounter (void)
    {
    LARGE_INTEGER int64;
    QueryPerformanceCounter (&int64);
    return int64.QuadPart;
    }

//-----------------------------------------------------------------------------
//
// queryPerformanceFrequency - same as QueryPerformanceFrequency, but returns  count direcly.

uint64_t queryPerformanceFrequency (void)
    {
    LARGE_INTEGER int64;

    QueryPerformanceFrequency (&int64);
    return int64.QuadPart;
    }

//----------------------------------------------------------------------------
//
// lfsr64gpr - left shift galois type lfsr for 64-bit data, general purpose register implementation
//
static uint64_t lfsr64gpr (uint64_t data, uint64_t mask)
   {
   uint64_t carryOut = data >> 63;
   uint64_t maskOrZ = -carryOut; 
   return (data << 1) ^ (maskOrZ & mask);
   }

//---------------------------------------------------------------------------

uint64_t min3 (uint64_t a, uint64_t b, uint64_t c)
    {
    uint64_t smallest;

    smallest = a;
    if (smallest > b) smallest = b;
    if (smallest > c) smallest = c;
    return smallest;
    }

//---------------------------------------------------------------------------

static void runtests (uint64_t pattern, uint64_t mask)
    {
    uint64_t startCount, elapsed, index, loops = 800000000;
    double ns;

    startCount = queryPerformanceCounter ();
    for (index = 0; index < loops; index++)
        {
        pattern = lfsr64gpr (pattern, mask);
        min3_a (pattern & 0xFFFFF, (pattern >> 20) & 0xFFFFF, (pattern >> 40) & 0xFFFFF);
        }
    elapsed = queryPerformanceCounter () - startCount;
    ns = (double) elapsed / queryPerformanceFrequency () * 1000000000 / loops;
    printf ("%7.2f ns\n", ns);

    startCount = queryPerformanceCounter ();
    for (index = 0; index < loops; index++)
        {
        pattern = lfsr64gpr (pattern, mask);
        min3_b (pattern & 0xFFFFF, (pattern >> 20) & 0xFFFFF, (pattern >> 40) & 0xFFFFF);
        }
    elapsed = queryPerformanceCounter () - startCount;
    ns = (double) elapsed / queryPerformanceFrequency () * 1000000000 / loops;
    printf ("%7.2f ns\n", ns);

    startCount = queryPerformanceCounter ();
    for (index = 0; index < loops; index++)
        {
        pattern = lfsr64gpr (pattern, mask);
        min3_c (pattern & 0xFFFFF, (pattern >> 20) & 0xFFFFF, (pattern >> 40) & 0xFFFFF);
        }
    elapsed = queryPerformanceCounter () - startCount;
    ns = (double) elapsed / queryPerformanceFrequency () * 1000000000 / loops;
    printf ("%7.2f ns\n", ns);
    }

//---------------------------------------------------------------------------

int main (void)
    {
    uint64_t index, pattern, mask;
    uint64_t count_a = 0, count_b = 0, count_c = 0;

    mask = 0xBEFFFFFFFFFFFFFF;
    pattern = 1;

    // sanity check the asm functions
    for (index = 0; index < 1000000; index++)
        {
        uint64_t expected, result_a, result_b, result_c;
        uint64_t pattern1 = (pattern >>  0) & 0xFFFFF;
        uint64_t pattern2 = (pattern >> 20) & 0xFFFFF;
        uint64_t pattern3 = (pattern >> 40) & 0xFFFFF;
        pattern = lfsr64gpr (pattern, mask);
        expected = min3 (pattern1, pattern2, pattern3);
        result_a = min3_a (pattern1, pattern2, pattern3);
        result_b = min3_b (pattern1, pattern2, pattern3);
        result_c = min3_c (pattern1, pattern2, pattern3);
        if (result_a != expected) printf ("min3_a fail, %llu %llu %llu %llu %llu\n", expected, result_a, pattern1, pattern2, pattern3);
        if (result_b != expected) printf ("min3_b fail, %llu %llu %llu %llu %llu\n", expected, result_b, pattern1, pattern2, pattern3);
        if (result_c != expected) printf ("min3_c fail, %llu %llu %llu %llu %llu\n", expected, result_c, pattern1, pattern2, pattern3);
        if (expected == pattern1) count_a++;
        if (result_b == pattern2) count_b++;
        if (result_c == pattern3) count_c++;
        }
    printf ("pseudo-random distribution: %llu, %llu, %llu\n", count_a, count_b, count_c);


    // raise our priority to increase measurement accuracy
    SetPriorityClass (GetCurrentProcess (), REALTIME_PRIORITY_CLASS);

    printf ("using pseudo-random data\n");
    runtests (1, mask);
    printf ("using fixed data\n");
    runtests (0, mask);
    return 0;
    }

//---------------------------------------------------------------------------

构建命令行:

gcc -Wall -Wextra -O3 -omin3test.exe min3test.c functions.s
于 2013-08-25T01:57:04.440 回答
6

最少两个无符号数有经典解:

; eax = min(eax, ebx), ecx - scratch register.
.min2:
    sub     ebx, eax
    sbb     ecx, ecx
    and     ecx, ebx
    add     eax, ecx

这种方法可能比使用 cmov 的解决方案更快,但为了获得更高的速度,必须将指令与其他指令分开才能并行执行。

可以对三个数字实现此方法:

; eax = min(eax, ebx, edx), ecx - scratch register.
.min3:
    sub     ebx, eax
    sbb     ecx, ecx
    and     ecx, ebx
    add     eax, ecx

    sub     edx, eax
    sbb     ecx, ecx
    and     ecx, edx
    add     eax, ecx

另一种尝试是使用条件跳转来测试变体。对于现代处理器,它可能会更快,特别是如果跳跃是高度可预测的:

.min3:
    cmp     eax, ebx
    jle     @f
    mov     eax, ebx
@@:
    cmp     eax, edx
    jle     @f
    mov     eax, edx
@@:
于 2013-08-19T17:39:23.910 回答
1

该函数min(x,y,z)是连续的,但它的导数不是。该导数在其定义的任何地方都有范数 1。只是没有办法将其表达为算术函数。

饱和算术有其自身的不连续性,因此在这种情况下不能使用前面的推理。但是,饱和点与输入无关。这反过来意味着您需要缩放输入,此时我相信生成的代码不会更快。

这当然不能完全证明不存在更快的代码,但您可能需要进行详尽的搜索。

于 2013-08-19T07:55:28.530 回答