2

我有两个 uint8_t 数组,它们各有 64 个元素。为了计算所有元素的 SAD,我想出的"最佳"方法是:分 4 次、每次加载 16 个元素,把它们放入两个 __m128i 寄存器,再把这两个寄存器合并进一个 __m256i 寄存器。两个 uint8_t 数组都按如下方式处理:

__m128i a1, a2, b1, b2, s1, s2;
__m256i u, v, c;

/*
 * Each 128-bit register holds two 8-byte rows: the row at the current pointer
 * in the low 64 bits and the next row (pointer + stride) in the high 64 bits.
 * _mm_loadl_epi64 (MOVQ) is an unaligned 8-byte load straight into an SSE
 * register, so no MMX __m64 values (and therefore no EMMS) are involved.
 * The order of rows inside a register does not matter for a SAD.
 */
a1 = _mm_unpacklo_epi64(_mm_loadl_epi64((const __m128i *)block1),
                        _mm_loadl_epi64((const __m128i *)(block1 + stride)));
block1 += stride + stride;
a2 = _mm_unpacklo_epi64(_mm_loadl_epi64((const __m128i *)block1),
                        _mm_loadl_epi64((const __m128i *)(block1 + stride)));

/*
 * a1 goes to the low 128-bit lane, a2 to the high lane. The cast leaves the
 * upper lane undefined, but the insert overwrites it. The AVX2 integer-domain
 * insert (VINSERTI128) is used for integer data.
 */
u = _mm256_inserti128_si256(_mm256_castsi128_si256(a1), a2, 1);

b1 = _mm_unpacklo_epi64(_mm_loadl_epi64((const __m128i *)block2),
                        _mm_loadl_epi64((const __m128i *)(block2 + stride)));
block2 += stride + stride;
b2 = _mm_unpacklo_epi64(_mm_loadl_epi64((const __m128i *)block2),
                        _mm_loadl_epi64((const __m128i *)(block2 + stride)));

v = _mm256_inserti128_si256(_mm256_castsi128_si256(b1), b2, 1);

现在我有了两个 __m256i 寄存器 u 和 v,可以计算 SAD 了:

c = _mm256_sad_epu8(v, u); /* |v - u| == |u - v|: one 16-bit sum per 64-bit quarter */

但是(可能因为时间太晚了),我想不出更好的办法把结果取出来……目前我的写法是:

s1 = _mm256_extractf128_si256(c, 0x0);
s2 = _mm256_extractf128_si256(c, 0x1);

int p, q;
p = _mm_extract_epi32(s1, 0x0);
q = _mm_extract_epi32(s1, 0x2);
*result += p + q;

p = _mm_extract_epi32(s2, 0x0);
q = _mm_extract_epi32(s2, 0x2);
*result += p + q;

补充说明:result 是一个 int。

这会生成相当多的指令。在我看来,上面的加载方式是载入所需全部 uint8_t 的唯一办法;但从 __m256i 寄存器 c 中取出结果的方式可能并不是最优的。

你说什么?你能帮我以更优化的方式做到这一点吗?

整合起来,函数大致如下:

/*
 * Computes the sum of absolute differences (SAD) between two 8x8 blocks of
 * uint8_t and stores it in *result.
 *
 * block1, block2: top-left byte of each block; row r starts at
 *                 block + r*stride and only its first 8 bytes are read.
 * stride:         distance in bytes between consecutive rows (both blocks).
 * result:         receives the SAD, in the range 0 .. 64*255.
 *
 * Rows are fetched with 8-byte MOVQ loads (no alignment requirement, nothing
 * read beyond the 8 bytes of each row) and paired two per 128-bit register.
 * When compiled with AVX2, two VPSADBWs cover all 64 bytes; otherwise four
 * SSE2 PSADBWs are used. The partial sums are reduced in vector registers and
 * moved to *result with a single MOVD.
 */
void foobar(uint8_t *block1, uint8_t *block2, int stride, int *result)
{
  __m128i a[4], b[4], acc;
  int i;

  /* a[i] / b[i] hold row 2i in the low 64 bits and row 2i+1 in the high 64. */
  for (i = 0; i < 4; ++i) {
    const uint8_t *p1 = block1 + (2 * i) * stride;
    const uint8_t *p2 = block2 + (2 * i) * stride;
    a[i] = _mm_unpacklo_epi64(_mm_loadl_epi64((const __m128i *)p1),
                              _mm_loadl_epi64((const __m128i *)(p1 + stride)));
    b[i] = _mm_unpacklo_epi64(_mm_loadl_epi64((const __m128i *)p2),
                              _mm_loadl_epi64((const __m128i *)(p2 + stride)));
  }

#if defined(__AVX2__)
  {
    /* Rows 0-3 and rows 4-7 of each block, one ymm register each. */
    const __m256i u0 = _mm256_inserti128_si256(_mm256_castsi128_si256(a[0]), a[1], 1);
    const __m256i u1 = _mm256_inserti128_si256(_mm256_castsi128_si256(a[2]), a[3], 1);
    const __m256i v0 = _mm256_inserti128_si256(_mm256_castsi128_si256(b[0]), b[1], 1);
    const __m256i v1 = _mm256_inserti128_si256(_mm256_castsi128_si256(b[2]), b[3], 1);
    const __m256i c  = _mm256_add_epi64(_mm256_sad_epu8(u0, v0),
                                        _mm256_sad_epu8(u1, v1));
    acc = _mm_add_epi64(_mm256_castsi256_si128(c), _mm256_extracti128_si256(c, 1));
  }
#else
  acc = _mm_add_epi64(_mm_add_epi64(_mm_sad_epu8(a[0], b[0]), _mm_sad_epu8(a[1], b[1])),
                      _mm_add_epi64(_mm_sad_epu8(a[2], b[2]), _mm_sad_epu8(a[3], b[3])));
#endif

  /* Fold the upper 64-bit partial sum onto the lower one. */
  acc = _mm_add_epi64(acc, _mm_unpackhi_epi64(acc, acc));
  *result = _mm_cvtsi128_si32(acc);
}

由于数据的存储布局,我一次只能连续加载 8 个 uint8_t,之后必须按 stride 跳到下一行。如果一次直接加载 16 个,就会得到错误的结果。

4

3 回答 3

1

要用 SSE 求两个字节数组的绝对差之和,我会这样写:

__m128i sum1 = _mm_sad_epu8(u,v);          /* one partial sum per 64-bit half */
__m128i sum2 = _mm_shuffle_epi32(sum1,2);  /* move dword 2 (upper partial sum) to dword 0 */
__m128i sum3 = _mm_add_epi32(sum1,sum2);
int     sum4 = _mm_cvtsi128_si32(sum3);    /* up to 16*255 = 4080: must not be int8_t */

我现在没有 AVX2 环境可以测试;下面是未经测试的代码,我会先从它开始尝试:

__m256i sum1 = _mm256_sad_epu8(u,v);          /* one partial sum per 64-bit quarter */
__m256i sum2 = _mm256_shuffle_epi32(sum1,2);  /* per 128-bit lane: dword 2 -> dword 0 */
__m256i sum3 = _mm256_add_epi32(sum1,sum2);
__m128i sum4 = _mm_add_epi32(_mm256_castsi256_si128(sum3),
                             _mm256_extracti128_si256(sum3,1));
int     sum5 = _mm_cvtsi128_si32(sum4);       /* up to 32*255 = 8160: must not be int8_t */

我可以稍后测试这个。

于 2014-09-05T08:16:15.637 回答
1

根据上面评论中的讨论,我写了一个使用 VMPSADBW 指令进行 8x8 运动估计的可运行示例。GCC 4.8.1 为它生成的代码让我有些失望,但这是一个很好的起点。示例中包含两个测试,用于验证功能,并演示新函数 sad_block_8x8_range() 的用法。

内部循环使用 8 次加载、8 条 VMPSADBW、7 次垂直加法、一次 shuffle 和一次归约加法,计算一个 8x8 块与原始图像中 8 个相互重叠的块之间的 SAD。用 | 0xFFFF(按位或)屏蔽掉无效的 SAD 后,PHMINPOSUW 指令会直接给出最低 SAD 及其索引。该指令返回 xmm 寄存器中八个无符号 16 位值里的最小值及其索引。如果这个 SAD 比当前最佳值还低,就把它连同该索引一起保存下来。

/* Includes */
#include <stdint.h>
#include <string.h>
#include <stdio.h>
#include <immintrin.h>




/* Typedefs: short aliases for the fixed-width unsigned integer types. */
typedef uint64_t u64;
typedef uint32_t u32;
typedef uint16_t u16;
typedef uint8_t  u8;






/* Functions */

/**
 * Define to 0 (e.g. compile with -DNO_OVERALLOCATION=0) if the image is
 * guaranteed to have at least 8 readable bytes beyond its nominal end. With
 * the default of 1, the last row of a candidate group is loaded so that it
 * never reads past column maxDx+6 of that row.
 */
#ifndef NO_OVERALLOCATION
#define NO_OVERALLOCATION 1
#endif

/**
 * Exhaustive 8x8 block-matching search (motion estimation) using AVX2.
 *
 * For every candidate vector (dx, dy) with 0 <= dx < maxDx and
 * 0 <= dy < maxDy, computes the sum of absolute differences (SAD) between the
 * 8x8 reference block and the 8x8 block of the original image whose top-left
 * pixel is orig + dy*os + dx. It reports the smallest SAD and its vector.
 * Eight horizontally adjacent candidates are evaluated per inner iteration
 * with VMPSADBW, and the best of them is selected with PHMINPOSUW. On ties,
 * the candidate with the smallest dy wins, then the one with the smallest dx.
 *
 * Memory reads: for each group of eight candidates starting at column dx,
 * rows 0-6 are read with 16-byte loads covering columns dx .. dx+15, so the
 * caller must make those bytes readable. Row 7 is read the same way when
 * NO_OVERALLOCATION is 0. Otherwise, when fewer than 9 candidates remain in
 * the row, it is read from up to 8 bytes before column dx up to column
 * maxDx+6 only.
 *
 * Requires AVX2 (VMPSADBW on ymm), SSSE3 (PSHUFB) and SSE4.1 (PHMINPOSUW).
 *
 * @param [in]  orig   A pointer into the image within which to run ME. Points
 *                     to a base offset from which a window of maxDx pixels to
 *                     the right and maxDy pixels down is explored to find the
 *                     lowest SAD.
 * @param [in]  os     The span of the original image.
 * @param [in]  ref    A pointer to the 8x8 reference block being SAD-ed for in
 *                     the original image.
 * @param [in]  rs     The span of the 8x8 reference block.
 * @param [in]  maxDx  The width of the search window for ME. Must not be 0
 *                     (if it is, nothing is searched and *sadOut is 0xFFFF).
 * @param [in]  maxDy  The height of the search window for ME. Must not be 0
 *                     (if it is, nothing is searched and *sadOut is 0xFFFF).
 * @param [out] sadOut The lowest SAD found. May be NULL.
 * @param [out] dxOut  The corresponding best vector found, x-coordinate.
 *                     May be NULL.
 * @param [out] dyOut  The corresponding best vector found, y-coordinate.
 *                     May be NULL.
 */

void sad_block_8x8_range(const u8* orig,
                         unsigned  os,
                         const u8* ref,
                         unsigned  rs,
                         unsigned  maxDx,
                         unsigned  maxDy,
                         unsigned* sadOut,
                         unsigned* dxOut,
                         unsigned* dyOut){
    __m128i   tmp01, tmp23, tmp45, tmp67;
    __m256i   r01,   r23,   r45,   r67;
    __m256i   o0, o1, o2, o3, o4, o5, o6, o7;
    const u8* origTmp;
    unsigned  tmpDx, dx, dy, sad;
    unsigned  minDx = 0, minDy = 0, minSAD = 0xFFFF;

    /*
     * Load the eight 8-byte rows of the reference block, two per xmm register
     * (even row in the low 8 bytes, odd row in the high 8 bytes). MOVQ zeroes
     * the upper half, so no register is read before it has been written.
     */
    tmp01 = _mm_unpacklo_epi64(_mm_loadl_epi64((const __m128i*)(ref + 0*rs)),
                               _mm_loadl_epi64((const __m128i*)(ref + 1*rs)));
    tmp23 = _mm_unpacklo_epi64(_mm_loadl_epi64((const __m128i*)(ref + 2*rs)),
                               _mm_loadl_epi64((const __m128i*)(ref + 3*rs)));
    tmp45 = _mm_unpacklo_epi64(_mm_loadl_epi64((const __m128i*)(ref + 4*rs)),
                               _mm_loadl_epi64((const __m128i*)(ref + 5*rs)));
    tmp67 = _mm_unpacklo_epi64(_mm_loadl_epi64((const __m128i*)(ref + 6*rs)),
                               _mm_loadl_epi64((const __m128i*)(ref + 7*rs)));

    /* Duplicate each pair of rows into both 128-bit lanes of a ymm register. */
    r01 = _mm256_inserti128_si256(_mm256_castsi128_si256(tmp01), tmp01, 1);
    r23 = _mm256_inserti128_si256(_mm256_castsi128_si256(tmp23), tmp23, 1);
    r45 = _mm256_inserti128_si256(_mm256_castsi128_si256(tmp45), tmp45, 1);
    r67 = _mm256_inserti128_si256(_mm256_castsi128_si256(tmp67), tmp67, 1);

    /* Iterate over x and y of search space. */
    for(dy=0;dy<maxDy;dy++){
        for(dx=0;dx<maxDx;dx+=8){
            /* Number of candidates of this group inside the window (>= 1). */
            const unsigned nValid = maxDx - dx;

            /* Broadcast 16-byte image rows to both halves of a ymm register. */
            origTmp = orig + dy*os + dx;
            o0 = _mm256_broadcastsi128_si256(_mm_loadu_si128((const __m128i*)(origTmp)));origTmp += os;
            o1 = _mm256_broadcastsi128_si256(_mm_loadu_si128((const __m128i*)(origTmp)));origTmp += os;
            o2 = _mm256_broadcastsi128_si256(_mm_loadu_si128((const __m128i*)(origTmp)));origTmp += os;
            o3 = _mm256_broadcastsi128_si256(_mm_loadu_si128((const __m128i*)(origTmp)));origTmp += os;
            o4 = _mm256_broadcastsi128_si256(_mm_loadu_si128((const __m128i*)(origTmp)));origTmp += os;
            o5 = _mm256_broadcastsi128_si256(_mm_loadu_si128((const __m128i*)(origTmp)));origTmp += os;
            o6 = _mm256_broadcastsi128_si256(_mm_loadu_si128((const __m128i*)(origTmp)));origTmp += os;

            if(NO_OVERALLOCATION && nValid < 9){
                /*
                 * Last row: only origTmp[0 .. nValid+6] is needed, and a plain
                 * 16-byte load would over-read. Load the 16 bytes that END at
                 * origTmp[nValid+6], i.e. starting at origTmp - (9 - nValid).
                 * Then shift them down with PSHUFB: out[i] = loaded[i + 9 - nValid].
                 * Lanes i > nValid+6 receive wrapped-around bytes (PSHUFB uses
                 * the index modulo 16). Those bytes only feed candidates that
                 * are masked out below.
                 */
                const __m128i dealigned = _mm_loadu_si128((const __m128i*)(origTmp - (9 - nValid)));
                const __m128i shufmsk   = _mm_add_epi8(
                    _mm_setr_epi8(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15),
                    _mm_set1_epi8((char)(9 - nValid)));
                o7 = _mm256_broadcastsi128_si256(_mm_shuffle_epi8(dealigned, shufmsk));
            }else{
                o7 = _mm256_broadcastsi128_si256(_mm_loadu_si128((const __m128i*)(origTmp)));
            }

            /*
             * The upper 128-bit lane computes the SAD of the right 4 bytes of
             * each reference row, and the lower lane the left 4 bytes, each
             * against eight consecutive positions. Adding all sixteen half-row
             * results gives the SADs of the eight full 8x8 candidates.
             */
            __m256i s0 = _mm256_mpsadbw_epu8(o0, r01, 1<<5 | 1<<3 | 0<<2 | 0<<0);
            __m256i s1 = _mm256_mpsadbw_epu8(o1, r01, 1<<5 | 3<<3 | 0<<2 | 2<<0);
            __m256i s2 = _mm256_mpsadbw_epu8(o2, r23, 1<<5 | 1<<3 | 0<<2 | 0<<0);
            __m256i s3 = _mm256_mpsadbw_epu8(o3, r23, 1<<5 | 3<<3 | 0<<2 | 2<<0);
            __m256i s4 = _mm256_mpsadbw_epu8(o4, r45, 1<<5 | 1<<3 | 0<<2 | 0<<0);
            __m256i s5 = _mm256_mpsadbw_epu8(o5, r45, 1<<5 | 3<<3 | 0<<2 | 2<<0);
            __m256i s6 = _mm256_mpsadbw_epu8(o6, r67, 1<<5 | 1<<3 | 0<<2 | 0<<0);
            __m256i s7 = _mm256_mpsadbw_epu8(o7, r67, 1<<5 | 3<<3 | 0<<2 | 2<<0);

            /* Accumulate half-row results together into half-block results */
            s0 = _mm256_add_epi16(s0, s1);
            s0 = _mm256_add_epi16(s0, s2);
            s0 = _mm256_add_epi16(s0, s3);
            s0 = _mm256_add_epi16(s0, s4);
            s0 = _mm256_add_epi16(s0, s5);
            s0 = _mm256_add_epi16(s0, s6);
            s0 = _mm256_add_epi16(s0, s7);

            /* Accumulate half-block results into block results */
            __m128i t0 = _mm256_extracti128_si256(s0, 0);
            __m128i t1 = _mm256_extracti128_si256(s0, 1);
            __m128i t  = _mm_add_epi16(t0, t1);

            /*
             * Force the SADs of candidates at or beyond maxDx to 0xFFFF so
             * PHMINPOSUW never selects them (a real SAD is at most 64*255).
             * Lane j is invalid iff j >= nValid, i.e. j > min(nValid, 8) - 1.
             * Computing the mask avoids a table that would need 16-byte alignment.
             */
            const __m128i hm = _mm_cmpgt_epi16(_mm_setr_epi16(0, 1, 2, 3, 4, 5, 6, 7),
                                               _mm_set1_epi16((short)((nValid < 8 ? nValid : 8) - 1)));
            const __m128i h  = _mm_minpos_epu16(_mm_or_si128(t, hm));
            sad   =      (u16)_mm_extract_epi16(h, 0);
            tmpDx = dx + (u16)_mm_extract_epi16(h, 1);

            /* Save the result if it is the best so far. */
            if(sad < minSAD){
                minDx  = tmpDx;
                minDy  = dy;
                minSAD = sad;
            }
        }
    }

    if(sadOut){
        *sadOut = minSAD;
    }
    if(dxOut){
        *dxOut = minDx;
    }
    if(dyOut){
        *dyOut = minDy;
    }
}

/**
 * MAIN.
 *
 * Runs two testcases against sad_block_8x8_range() and prints each outcome.
 * Returns 0 when both pass and 1 otherwise, so the program can serve as an
 * automated check.
 */

int main(void){
    const u8 ref[] = {
        0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
        0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f,
        0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17,
        0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f,
        0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27,
        0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f,
        0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37,
        0x38, 0x39, 0x3a, 0x3b, 0x3c, 0x3d, 0x3e, 0x3f,
    };
    /* 16x16 image with an exact copy of ref at column 7, row 3. */
    const u8 img0[] = {
        0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
        0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
        0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
        0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0xFF,
        0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0xFF,
        0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0xFF,
        0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0xFF,
        0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0xFF,
        0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f, 0xFF,
        0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0xFF,
        0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x38, 0x39, 0x3a, 0x3b, 0x3c, 0x3d, 0x3e, 0x3f, 0xFF,
        0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
        0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
        0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
        0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
        0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF
    };
    /* 16x16 image with ref at column 8, row 4; its last byte is 0x40 instead of 0x3f. */
    const u8 img1[] = {
        0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
        0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
        0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
        0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
        0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
        0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f,
        0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17,
        0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f,
        0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27,
        0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f,
        0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37,
        0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x38, 0x39, 0x3a, 0x3b, 0x3c, 0x3d, 0x3e, 0x40,
        0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
        0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
        0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
        0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF
    };
    unsigned sad, dx, dy;
    int      failures = 0;

    sad_block_8x8_range(img0, 16, ref, 8, 9, 9, &sad, &dx, &dy);
    if(sad == 0 && dx == 7 && dy == 3){
        printf("Test 1 PASSED!\n");
    }else{
        printf("Test 1 FAILED! (SAD = %u, MV=(%u, %u))\n", sad, dx, dy);
        failures++;
    }

    sad_block_8x8_range(img1, 16, ref, 8, 9, 9, &sad, &dx, &dy);
    if(sad == 1 && dx == 8 && dy == 4){
        printf("Test 2 PASSED!\n");
    }else{
        printf("Test 2 FAILED! (SAD = %u, MV=(%u, %u))\n", sad, dx, dy);
        failures++;
    }

    /* Nonzero exit status lets scripts and CI detect a failing test. */
    return failures ? 1 : 0;
}
于 2014-09-06T08:33:56.463 回答
0

为了澄清,我假设您要优化以下代码:

int sad_2x64_normal(uint8_t *ptr0, uint8_t *ptr1)
{
  // Scalar reference: sum over 64 byte pairs of |ptr0[k] - ptr1[k]|.
  int total = 0;
  for (int k = 0; k < 64; ++k) {
    const int diff = static_cast<int>(ptr0[k]) - static_cast<int>(ptr1[k]);
    total += (diff < 0) ? -diff : diff;
  }
  return total;
}

这个问题带有 AVX2 标签,我的简短解决方案是:

int sad_2x64_avx2(uint8_t *ptr0, uint8_t *ptr1)
{
  register __m256i  r0;
  register __m256i  r1;
  register __m256i  r2;
  register __m256i  r3;

  r0 = _mm256_load_si256(reinterpret_cast<__m256i const *>(ptr0)); // load 32 bytes (aligned)
  r1 = _mm256_load_si256(reinterpret_cast<__m256i const *>(ptr1)); // load 32 bytes (aligned)

  r2 = _mm256_sad_epu8(r0, r1);    // 4 unsigned 64 bit value

  r0 = _mm256_load_si256(reinterpret_cast<__m256i const *>(ptr0+32));
  r1 = _mm256_load_si256(reinterpret_cast<__m256i const *>(ptr1+32));

  r3 = _mm256_sad_epu8(r0, r1);   

  r2 = _mm256_add_epi16(r2, r3);
  r2 = _mm256_shuffle_epi32(r2, 0xE8);    
  r2 = _mm256_hadd_epi32(r2, r2);
  r2 = _mm256_permute4x64_epi64(r2, 0xE8);
  r2 = _mm256_shuffle_epi32(r2, 0xE8);   
  r2 = _mm256_hadd_epi32(r2, r2);    
  return _mm_extract_epi16(_mm256_castsi256_si128(r2), 0);
}

下面是同一段代码,附带完整注释:

#include "emmintrin.h"
#include "immintrin.h"

// 
// returns \sum_0^{63} abs(ptr0[i]-ptr1[i])
// assume ptr0 and ptr1 are 32 byte aligned
//
int sad_2x64_avx2(uint8_t *ptr0, uint8_t *ptr1)
{
  register __m256i  r0;
  register __m256i  r1;
  register __m256i  r2;
  register __m256i  r3;

  // 1st 32 bytes

  r0 = _mm256_load_si256(reinterpret_cast<__m256i const *>(ptr0));    // load 32 bytes (aligned)
  r1 = _mm256_load_si256(reinterpret_cast<__m256i const *>(ptr1));    // load 32 bytes (aligned)

  r2 = _mm256_sad_epu8(r0, r1);    // results stored as 4x64

  // 2nd 32 bytes

  r0 = _mm256_load_si256(reinterpret_cast<__m256i const *>(ptr0+32));    // load 32 bytes (aligned)
  r1 = _mm256_load_si256(reinterpret_cast<__m256i const *>(ptr1+32));    // load 32 bytes (aligned)

  r3 = _mm256_sad_epu8(r0, r1);


  // after sad_epu8
  //
  //       255                             127             63      31
  //      |                               |               |       |
  // r2 = | 0 | 0 | 0 | a | 0 | 0 | 0 | b | 0 | 0 | 0 | c | 0 | 0 | 0 | d |
  // r3 = | 0 | 0 | 0 | e | 0 | 0 | 0 | f | 0 | 0 | 0 | g | 0 | 0 | 0 | h | 

  r2 = _mm256_add_epi16(r2, r3); 

  // after add_epi16
  //
  //       255                             127             63      31
  //      |                               |               |       |
  // r2 = | 0 | 0 | 0 | i | 0 | 0 | 0 | j | 0 | 0 | 0 | k | 0 | 0 | 0 | l |

  r2 = _mm256_shuffle_epi32(r2, 0xE8); // binary 11_10_10_00

  // after shuffle
  //
  //       255                             127             63      31
  //      |                               |               |       |
  // r2 = | 0 | 0 | 0 | i | 0 | i | 0 | j | 0 | 0 | 0 | k | 0 | k | 0 | l | 

  r2 = _mm256_hadd_epi32(r2, r2);  

  // after hadd
  //
  //       255                             127             63      31
  //      |                               |               |       |
  // r2 = |     i |   i+j |     i |  i+j  |     k |   k+l |     k |   k+l | 

  r2 = _mm256_permute4x64_epi64(r2, 0xE8); // 11_10_10_00
  r2 = _mm256_shuffle_epi32(r2, 0xE8); // binary 11_10_10_00

  // after permute and shuffle
  //
  //       255                             127             63      31
  //      |                               |               |       |
  // r2 = |     i |   i+j |   i+j |  i+j  |     i |   i+j |   i+j |   k+l | 

  r2 = _mm256_hadd_epi32(r2, r2);

  // after hadd and shuffle
  //
  //       255                             127             63      31
  //      |                               |               |       |
  // r2 = | ....  | ....  | ....  | ....  |  .... | ....  | ....  |i+j+k+l| 


  return _mm_extract_epi16(_mm256_castsi256_si128(r2), 0);
}

希望能帮助到你!

在我的机器(Haswell Core i7 4900MQ)上使用 g++ 4.8.2

g++ -march=core-avx2 -Wall -Wextra -std=c++11 -O3 ...

我观察到 AVX2 版本比普通版本快了 58 倍!

于 2014-09-24T14:20:00.563 回答