4

如何在不循环的情况下实现对位掩码的操作,这对于两个位掩码ab宽度n给出具有以下属性的宽度位掩码:c2 * n

  • ic仅当存在j-th 位ak-th 位 in时才设置 -thbj + k == i

C++ 实现:

#include <bitset>
#include <algorithm>
#include <iostream>

#include <cstdint>
#include <cassert>

#include <x86intrin.h>

std::uint64_t multishift(std::uint32_t a, std::uint32_t b)
{
    std::uint64_t c = 0;
    if (_popcnt32(b) < _popcnt32(a)) {
        std::swap(a, b);
    }
    assert(a != 0);
    do {
        c |= std::uint64_t{b} << (_bit_scan_forward(a) + 1);
    } while ((a &= (a - 1)) != 0); // clear least set bit
    return c;
}

int main()
{
    std::cout << std::bitset< 64 >(multishift(0b1001, 0b0101)) << std::endl; // ...0001011010
}

可以使用一些位技巧或一些 x86 指令在不循环的情况下重新实现它吗?

4

2 回答 2

6

这就像使用 OR 而不是 ADD 的乘法。据我所知,没有真正惊人的技巧。但这是一个实际上避免内在函数而不是使用它们的技巧:

while (a) {
    c |= b * (a & -a);
    a &= a - 1;
}

这与您的算法非常相似,但使用乘法b通过尾随零计数左移aa & -a这是仅选择最低设置位作为掩码的技巧。作为奖励,该表达式可以安全地执行 when a == 0,因此您可以展开(和/或将 while 转换为 do/while 而没有先决条件)而不会出现令人讨厌的边缘情况(TZCNT 和 shift 不是这种情况)。


pshufb可以在并行查表模式下使用,使用一个半字节a来选择一个子表,然后使用它在一条指令中将所有半字节乘以b该半字节。a对于适当的乘法,pshufb最大为 8 秒(或始终为 8,因为尝试提前退出的意义较小)。它确实需要在开始时进行一些奇怪的设置,并需要一些不幸的水平东西来完成它,所以它可能不是那么好。

于 2018-03-01T19:28:33.233 回答
3

由于这个问题被标记为 BMI(Intel Haswell 和更新版本,AMD Excavator 和更新版本),我认为 AVX2 在这里也很有趣。请注意,所有支持 BMI2 的处理器也支持 AVX2。在英特尔 Skylake 等现代硬件上,与标量(do/while)循环相比,蛮力 AVX2 方法可能表现得相当好,除非ab只有很少的位设置。下面的代码在不搜索输入中的非零位的情况下计算所有 32 次乘法(参见 Harold 的乘法思想)。在相乘之后,所有东西都被 OR-ed 在一起。

对于 的随机输入值a,标量代码(见上文:Orient 的问题或 Harold 的回答)可能会遭受分支错误预测:每个周期的指令数 (IPC) 小于无分支 AVX2 解决方案的 IPC。我用不同的代码做了一些测试。事实证明,Harold 的代码可能会极大地受益于循环展开。在这些测试中,测量的是吞吐量,而不是延迟。考虑两种情况: 1.a平均设置 3.5 位的随机数。a2.平均设置为 16.0 位的随机数 (50%)。中央处理器是英特尔 Skylake i5-6500。

Time in sec.
                             3.5 nonz      16 nonz  
mult_shft_Orient()             0.81          1.51                
mult_shft_Harold()             0.84          1.51                
mult_shft_Harold_unroll2()     0.64          1.58                
mult_shft_Harold_unroll4()     0.48          1.34                
mult_shft_AVX2()               0.44          0.40               

请注意,这些时间包括随机输入数的生成ab. AVX2 代码受益于快速vpmuludq指令:它在 Skylake 上每个周期的吞吐量为 2。


编码:

/*      gcc -Wall -m64 -O3 -march=broadwell mult_shft.c    */
#include <stdint.h>
#include <stdio.h>
#include <x86intrin.h>

uint64_t mult_shft_AVX2(uint32_t a_32, uint32_t b_32) {

    __m256i  a          = _mm256_broadcastq_epi64(_mm_cvtsi32_si128(a_32));
    __m256i  b          = _mm256_broadcastq_epi64(_mm_cvtsi32_si128(b_32));
                                                           //  0xFEDCBA9876543210  0xFEDCBA9876543210  0xFEDCBA9876543210  0xFEDCBA9876543210     
    __m256i  b_0        = _mm256_and_si256(b,_mm256_set_epi64x(0x0000000000000008, 0x0000000000000004, 0x0000000000000002, 0x0000000000000001));
    __m256i  b_1        = _mm256_and_si256(b,_mm256_set_epi64x(0x0000000000000080, 0x0000000000000040, 0x0000000000000020, 0x0000000000000010));
    __m256i  b_2        = _mm256_and_si256(b,_mm256_set_epi64x(0x0000000000000800, 0x0000000000000400, 0x0000000000000200, 0x0000000000000100));
    __m256i  b_3        = _mm256_and_si256(b,_mm256_set_epi64x(0x0000000000008000, 0x0000000000004000, 0x0000000000002000, 0x0000000000001000));
    __m256i  b_4        = _mm256_and_si256(b,_mm256_set_epi64x(0x0000000000080000, 0x0000000000040000, 0x0000000000020000, 0x0000000000010000));
    __m256i  b_5        = _mm256_and_si256(b,_mm256_set_epi64x(0x0000000000800000, 0x0000000000400000, 0x0000000000200000, 0x0000000000100000));
    __m256i  b_6        = _mm256_and_si256(b,_mm256_set_epi64x(0x0000000008000000, 0x0000000004000000, 0x0000000002000000, 0x0000000001000000));
    __m256i  b_7        = _mm256_and_si256(b,_mm256_set_epi64x(0x0000000080000000, 0x0000000040000000, 0x0000000020000000, 0x0000000010000000));

    __m256i  m_0        = _mm256_mul_epu32(a, b_0);
    __m256i  m_1        = _mm256_mul_epu32(a, b_1);
    __m256i  m_2        = _mm256_mul_epu32(a, b_2);
    __m256i  m_3        = _mm256_mul_epu32(a, b_3);
    __m256i  m_4        = _mm256_mul_epu32(a, b_4);
    __m256i  m_5        = _mm256_mul_epu32(a, b_5);
    __m256i  m_6        = _mm256_mul_epu32(a, b_6);
    __m256i  m_7        = _mm256_mul_epu32(a, b_7);

             m_0        = _mm256_or_si256(m_0, m_1);     
             m_2        = _mm256_or_si256(m_2, m_3);     
             m_4        = _mm256_or_si256(m_4, m_5);     
             m_6        = _mm256_or_si256(m_6, m_7);
             m_0        = _mm256_or_si256(m_0, m_2);     
             m_4        = _mm256_or_si256(m_4, m_6);     
             m_0        = _mm256_or_si256(m_0, m_4);     

    __m128i  m_0_lo     = _mm256_castsi256_si128(m_0);
    __m128i  m_0_hi     = _mm256_extracti128_si256(m_0, 1);
    __m128i  e          = _mm_or_si128(m_0_lo, m_0_hi);
    __m128i  e_hi       = _mm_unpackhi_epi64(e, e);
             e          = _mm_or_si128(e, e_hi);
    uint64_t c          = _mm_cvtsi128_si64x(e);
                          return c; 
}


uint64_t mult_shft_Orient(uint32_t a, uint32_t b) {
    uint64_t c = 0;
    do {
        c |= ((uint64_t)b) << (_bit_scan_forward(a) );
    } while ((a = a & (a - 1)) != 0);
    return c; 
}


uint64_t mult_shft_Harold(uint32_t a_32, uint32_t b_32) {
    uint64_t c = 0;
    uint64_t a = a_32;
    uint64_t b = b_32;
    while (a) {
       c |= b * (a & -a);
       a &= a - 1;
    }
    return c; 
}


uint64_t mult_shft_Harold_unroll2(uint32_t a_32, uint32_t b_32) {
    uint64_t c = 0;
    uint64_t a = a_32;
    uint64_t b = b_32;
    while (a) {
       c |= b * (a & -a);
       a &= a - 1;
       c |= b * (a & -a);
       a &= a - 1;
    }
    return c; 
}


uint64_t mult_shft_Harold_unroll4(uint32_t a_32, uint32_t b_32) {
    uint64_t c = 0;
    uint64_t a = a_32;
    uint64_t b = b_32;
    while (a) {
       c |= b * (a & -a);
       a &= a - 1;
       c |= b * (a & -a);
       a &= a - 1;
       c |= b * (a & -a);
       a &= a - 1;
       c |= b * (a & -a);
       a &= a - 1;
    }
    return c; 
}



int main(){

uint32_t a,b;

/*
uint64_t c0, c1, c2, c3, c4;
a = 0x10036011;
b = 0x31000107;

//a = 0x80000001;
//b = 0x80000001;

//a = 0xFFFFFFFF;
//b = 0xFFFFFFFF;

//a = 0x00000001;
//b = 0x00000001;

//a = 0b1001;
//b = 0b0101;

c0 = mult_shft_Orient(a, b);        
c1 = mult_shft_Harold(a, b);        
c2 = mult_shft_Harold_unroll2(a, b);
c3 = mult_shft_Harold_unroll4(a, b);
c4 = mult_shft_AVX2(a, b);          
printf("%016lX \n%016lX     \n%016lX     \n%016lX     \n%016lX \n\n", c0, c1, c2, c3, c4);
*/

uint32_t rnd = 0xA0036011;
uint32_t rnd_old;
uint64_t c;
uint64_t sum = 0;
double popcntsum =0.0;
int i;

for (i=0;i<100000000;i++){
   rnd_old = rnd;
   rnd = _mm_crc32_u32(rnd, i);      /* simple random generator                                   */
   b = rnd;                          /* the actual value of b has no influence on the performance */
   a = rnd;                          /* `a` has about 50% nonzero bits                            */

#if 1 == 1                           /* reduce number of set bits from about 16 to 3.5                 */
   a = rnd & rnd_old;                                   /* now `a` has about 25 % nonzero bits    */
          /*0bFEDCBA9876543210FEDCBA9876543210 */     
   a = (a & 0b00110000101001000011100010010000) | 1;    /* about 3.5 nonzero bits on average      */                  
#endif   
/*   printf("a = %08X \n", a);                */

//   popcntsum = popcntsum + _mm_popcnt_u32(a); 
                                               /*   3.5 nonz       50%   (time in sec.)  */
//   c = mult_shft_Orient(a, b );              /*      0.81          1.51                  */
//   c = mult_shft_Harold(a, b );              /*      0.84          1.51                */
//   c = mult_shft_Harold_unroll2(a, b );      /*      0.64          1.58                */
//   c = mult_shft_Harold_unroll4(a, b );      /*      0.48          1.34                */
   c = mult_shft_AVX2(a, b );                /*      0.44          0.40                */
   sum = sum + c;
}
printf("sum = %016lX \n\n", sum);

printf("average density = %f bits per uint32_t\n\n", popcntsum/100000000);

return 0;
}

函数mult_shft_AVX2()使用从零开始的位索引!就像哈罗德的回答一样。看起来在您的问题中,您从 1 开始计算位而不是 0。您可能希望将答案乘以 2 以获得相同的结果。

于 2018-03-07T22:28:06.403 回答