3

我需要知道具有存储在 __m128 中的最大绝对值的值的符号。这是我现在的解决方案:

int getMaxSign(__m128 const& vec) {
    static const __m128 SIGN_BIT_MASK = 
      _mm_castsi128_ps(_mm_set1_epi32(0x80000000));

    // This creates an int, where sign(a) is 1 if a is negative, 0 o.w.:
    // sign(a3)<<3 | sign(a2)<<2 | sign(a1)<<1 | sign(a0)
    const int signMask = _mm_movemask_ps(vec);

    // Get the absolute value of the vector;
    __m128 absValsMMX = _mm_andnot_ps(SIGN_BIT_MASK, vec);

    // Figure out the horizontal max
    __declspec(align(16)) float absVals[4];
    _mm_store_ps(absVals, absValsMMX);

    const float maxVal = std::max(std::max(absVals[0], absVals[1]), absVals[2]);

    return (maxVal == absVals[0] ? signMask & 0x1 : 
      (maxVal == absVals[1] ? signMask & 0x2 : signMask & 0x4));
}

在这种情况下,如果具有最大绝对值的值为负数,则符号将为 1,否则为 0,但我实际上并不关心约定是什么。另一件事是我使用这些 __m128s 表示同质向量,所以我知道最后一个值将始终为 0。

对于一个相对简单的任务,这似乎需要做很多工作。我怎样才能更快地做到这一点?

谢谢!

4

3 回答 3

4

这是一种可能的实现(在 C 中):

int getMaxSign(const __m128 v)
{
    __m128 v1, vmax, vmin, vsign;
    float sign;

    v1 = (__m128)_mm_alignr_epi8((__m128i)v, (__m128i)v, 4); // v1 = v rotated by 1 element
    vmax = _mm_max_ps(v, v1);           // generate horizontal max/min
    vmin = _mm_min_ps(v, v1);
    vmax = _mm_max_ps(vmax, (__m128)_mm_alignr_epi8((__m128i)vmax, (__m128i)vmax, 8));
    vmin = _mm_min_ps(vmin, (__m128)_mm_alignr_epi8((__m128i)vmin, (__m128i)vmin, 8));
    vsign = _mm_add_ps(vmax, vmin);     // add max and min to get sign of abs max
    sign = _mm_extract_ps(vsign, 0);
    return (int)(sign < 0.0f);          // return 1 for negative
}

尽管这看起来像很多代码,但它只有大约 9 条 SSE 指令,并且没有内存访问、没有分支和很少的标量代码。

请注意,上面同时使用了 SSSE3 和 SSE4.1 指令。

这是仅需要 SSSE3 的第二个版本:

int getMaxSign(const __m128 v)
{
    __m128 v1, vmax, vmin, vsign;
    int mask;

    v1 = (__m128)_mm_alignr_epi8((__m128i)v, (__m128i)v, 4); // v1 = v rotated by 1 element
    vmax = _mm_max_ps(v, v1);           // generate horizontal max/min
    vmin = _mm_min_ps(v, v1);
    vmax = _mm_max_ps(vmax, (__m128)_mm_alignr_epi8((__m128i)vmax, (__m128i)vmax, 8));
    vmin = _mm_min_ps(vmin, (__m128)_mm_alignr_epi8((__m128i)vmin, (__m128i)vmin, 8));
    vsign = _mm_add_ps(vmax, vmin);     // add max and min to get sign of abs max
    mask = _mm_movemask_epi8((__m128i)vsign);
    return (mask & 8) != 0;             // return 1 for negative
}

这会生成 12 条指令:

pshufd  $57, %xmm0, %xmm1
movdqa  %xmm0, %xmm2
minps   %xmm1, %xmm2
pshufd  $78, %xmm2, %xmm3
minps   %xmm3, %xmm2
maxps   %xmm1, %xmm0
pshufd  $78, %xmm0, %xmm1
maxps   %xmm1, %xmm0
addps   %xmm2, %xmm0
pmovmskb    %xmm0, %eax
shrl    $3, %eax
andl    $1, %eax

请注意编译器如何巧妙地更改palignr并仅使用 a和 anpshufd实现最终的标量测试。shrlandl


Visual Studio C/C++ 的注意事项:在 and 之间进行转换__m128__m128i您需要使用_mm_castps_si128and _mm_castsi128_ps,例如

    mask = _mm_movemask_epi8((__m128i)vsign);

将需要更改为:

    mask = _mm_movemask_epi8(_mm_castps_si128(vsign));
于 2012-11-26T15:56:31.943 回答
0

如果您的数字是离散的,并且间隔适当,并且来自有限的子集,那么还有其他可能性。

例如,如果您保证 a、b 和 c 是整数,那么您可以将向量自身相乘以获得奇次幂,然后用 <1, 1, 1> 点它。例如,如果我们将它自身相乘 4 次,它会得到 < a^5, b^5, c^5 >。如果 |a| 是最大的并且|a|=2,那么我们知道b和c将为1或0,所以a^3的值将占主导地位,点积将有其符号。例如,如果 X= < a=-2, b=1, c=0 > ,则 X^5 = <-32, 1, 0>。当你用 <1, 1, 1> 点它时,你得到 -31,它的符号反映了最大绝对值的符号。随着最大数的绝对值增加,它与其他项之间的差异将趋于收敛——例如,如果我们有 <-8, 7, 7>,那么上面的算法给出 X^5=<-32768 , 16807, 16807>,你用 <1 来点缀它,1, 1> 并得到 846,因此算法以指数 5 失败。如果我们将指数提高到 7,我们得到 <-2097152, 823543, 823543>,点缀着 <1, 1, 1> 给我们 -450066,这是正确的答案。最终舍入错误也会破坏这种方法。但是,如果您知道数据集的限制,我希望它可以对其他替代方案提供一些见解。

作为脚注,请记住 X^5 = (X*X) * (X*X) * X,因此您进行一次乘以得到 X^2,将其乘以自身得到 X^4,然后乘以 X - 总共三倍。您需要一个奇数指数来保留符号。

于 2012-11-26T16:40:25.293 回答
0
m = min(a,b,c);  
M = max(a,b,c);  

// return abs(m)>abs(M) ? sign(m): sign(M);   // was
return sign(m+M);

正如 Paul_R 正确注意到的那样,符号仅来自最小值和最大值的总和。具有较大(相反符号)绝对值的,获胜。

但是这个想法可以被更多地利用:min/max 的总和是相同的,作为所有元素的总和,减去中间的那个,可以通过 max 3 比较找到。

return sign(a+b+c - middle(a,b,c));  // or
return sign(a*aw + b*bw + c*cw);     // where aw,bw,cw = [0,1]

aw,bw,cw 可以从赢得比较的次数中得出(我认为必须为这种情况仔细计划,当有 2 或 3 个相等的值时。)

并进一步:

x = abs(b)>abs(a)?b:a;
return sign(x+c);

可能更进一步:

s = sign(a + b);  // store the sign of larger of a or b  
a = abs(a); b=abs(b);  
a = max(a,b) | s;   // somehow copy the sign.  
return sign(a+c);  
于 2012-11-26T15:03:39.743 回答