14

AXV2 没有任何大于 32 位源的整数乘法。它确实提供32 x 32 -> 32乘法,以及32 x 32 -> 64乘法1,但没有 64 位源。

假设我需要输入大于 32 位但小于或等于 52 位的无符号乘法 - 我可以简单地使用浮点DP 乘法或 FMA 指令,并且当整数输入和结果可以用 52 位或更少的位表示(即,在 [0, 2^52-1] 范围内)?

我想要产品的所有 104 位的更一般的情况怎么样?或者整数乘积超过 52 位的情况(即,乘积在位索引 > 52 中具有非零值) - 但我只想要低 52 位?在后一种情况下,MUL它将给我更高的位并舍入一些较低的位(也许这就是 IFMA 的帮助?)。

编辑:事实上,根据这个答案,它也许可以做任何高达 2^53 的事情——我忘记1了尾数之前的隐含前导有效地给了你一点。


1有趣的是,正如 Mysticial在评论中解释的那样,64 位产品PMULDQ操作的延迟是 32 位版本的一半,吞吐量是 32 位版本的两倍。PMULLD

4

3 回答 3

13

是的,这是可能的。但从 AVX2 开始,它不太可能比 MULX/ADCX/ADOX 的标量方法更好。

对于不同的输入/输出域,这种方法几乎有无限数量的变化。我只会介绍其中的 3 个,但是一旦你知道它们是如何工作的,它们就很容易概括。

免责声明:

  • 这里的所有解决方案都假设舍入模式是舍入到偶数。
  • 不建议使用快速数学优化标志,因为这些解决方案依赖于严格的 IEEE。

范围内的带符号双打: [-2 51 , 2 51 ]

//  A*B = L + H*2^52
//  Input:  A and B are in the range [-2^51, 2^51]
//  Output: L and H are in the range [-2^51, 2^51]
void mul52_signed(__m256d& L, __m256d& H, __m256d A, __m256d B){
    const __m256d ROUND = _mm256_set1_pd(30423614405477505635920876929024.);    //  3 * 2^103
    const __m256d SCALE = _mm256_set1_pd(1. / 4503599627370496);                //  1 / 2^52

    //  Multiply and add normalization constant. This forces the multiply
    //  to be rounded to the correct number of bits.
    H = _mm256_fmadd_pd(A, B, ROUND);

    //  Undo the normalization.
    H = _mm256_sub_pd(H, ROUND);

    //  Recover the bottom half of the product.
    L = _mm256_fmsub_pd(A, B, H);

    //  Correct the scaling of H.
    H = _mm256_mul_pd(H, SCALE);
}

这是最简单的一种,也是唯一一种与标量方法竞争的方法。最终缩放是可选的,具体取决于您要对输出执行的操作。所以这可以被认为只有3条指令。但它也是最没用的,因为输入和输出都是浮点值。

两个 FMA 保持融合是绝对关键的。这就是快速数学优化可以破坏事物的地方。如果第一个 FMA 被分解,则L不再保证在范围内[-2^51, 2^51]。如果第二个 FMA 被打破,那L将是完全错误的。


范围内的有符号整数: [-2 51 , 2 51 ]

//  A*B = L + H*2^52
//  Input:  A and B are in the range [-2^51, 2^51]
//  Output: L and H are in the range [-2^51, 2^51]
void mul52_signed(__m256i& L, __m256i& H, __m256i A, __m256i B){
    const __m256d CONVERT_U = _mm256_set1_pd(6755399441055744);     //  3*2^51
    const __m256d CONVERT_D = _mm256_set1_pd(1.5);

    __m256d l, h, a, b;

    //  Convert to double
    A = _mm256_add_epi64(A, _mm256_castpd_si256(CONVERT_U));
    B = _mm256_add_epi64(B, _mm256_castpd_si256(CONVERT_D));
    a = _mm256_sub_pd(_mm256_castsi256_pd(A), CONVERT_U);
    b = _mm256_sub_pd(_mm256_castsi256_pd(B), CONVERT_D);

    //  Get top half. Convert H to int64.
    h = _mm256_fmadd_pd(a, b, CONVERT_U);
    H = _mm256_sub_epi64(_mm256_castpd_si256(h), _mm256_castpd_si256(CONVERT_U));

    //  Undo the normalization.
    h = _mm256_sub_pd(h, CONVERT_U);

    //  Recover bottom half.
    l = _mm256_fmsub_pd(a, b, h);

    //  Convert L to int64
    l = _mm256_add_pd(l, CONVERT_D);
    L = _mm256_sub_epi64(_mm256_castpd_si256(l), _mm256_castpd_si256(CONVERT_D));
}

在第一个示例的基础上,我们将它与快速double <-> int64转换技巧的通用版本结合起来。

这个更有用,因为您正在使用整数。但即使使用快速转换技巧,大部分时间都将用于转换。幸运的是,如果您多次乘以相同的操作数,您可以消除一些输入转换。


范围内的无符号整数: [0, 2 52 )

//  A*B = L + H*2^52
//  Input:  A and B are in the range [0, 2^52)
//  Output: L and H are in the range [0, 2^52)
void mul52_unsigned(__m256i& L, __m256i& H, __m256i A, __m256i B){
    const __m256d CONVERT_U = _mm256_set1_pd(4503599627370496);     //  2^52
    const __m256d CONVERT_D = _mm256_set1_pd(1);
    const __m256d CONVERT_S = _mm256_set1_pd(1.5);

    __m256d l, h, a, b;

    //  Convert to double
    A = _mm256_or_si256(A, _mm256_castpd_si256(CONVERT_U));
    B = _mm256_or_si256(B, _mm256_castpd_si256(CONVERT_D));
    a = _mm256_sub_pd(_mm256_castsi256_pd(A), CONVERT_U);
    b = _mm256_sub_pd(_mm256_castsi256_pd(B), CONVERT_D);

    //  Get top half. Convert H to int64.
    h = _mm256_fmadd_pd(a, b, CONVERT_U);
    H = _mm256_xor_si256(_mm256_castpd_si256(h), _mm256_castpd_si256(CONVERT_U));

    //  Undo the normalization.
    h = _mm256_sub_pd(h, CONVERT_U);

    //  Recover bottom half.
    l = _mm256_fmsub_pd(a, b, h);

    //  Convert L to int64
    l = _mm256_add_pd(l, CONVERT_S);
    L = _mm256_sub_epi64(_mm256_castpd_si256(l), _mm256_castpd_si256(CONVERT_S));

    //  Make Correction
    H = _mm256_sub_epi64(H, _mm256_srli_epi64(L, 63));
    L = _mm256_and_si256(L, _mm256_set1_epi64x(0x000fffffffffffff));
}

最后我们得到了原始问题的答案。这通过调整转换和添加校正步骤来构建有符号整数解决方案。

但在这一点上,我们有 13 条指令——其中一半是高延迟指令,这还不包括众多的FP <-> int旁路延迟。因此,这不太可能赢得任何基准。相比之下,64 x 64 -> 128-bitSIMD 乘法可以在 16 条指令中完成(如果您对输入进行预处理,则需要 14 条。)

如果舍入模式是向下舍入或舍入为零,则可以省略校正步骤。唯一重要的指令是h = _mm256_fmadd_pd(a, b, CONVERT_U);. 因此,在 AVX512 上,您可以覆盖该指令的舍入,并单独保留舍入模式。


最后的想法:

值得注意的是,2 52的运算范围可以通过调整魔法常数来减小。这对于第一个解决方案(浮点解决方案)可能很有用,因为它为您提供了额外的尾数以用于浮点累加。这使您无需像最后两个解决方案一样在 int64 和 double 之间不断地来回转换。

虽然这里的 3 个示例不太可能比标量方法更好,但 AVX512 几乎肯定会打破平衡。Knights Landing 的 ADCX 和 ADOX 吞吐量尤其差。

当然,当 AVX512-IFMA 出现时,所有这些都没有实际意义。这将一个完整的52 x 52 -> 104-bit产品减少到 2 条指令,并免费提供累积。

于 2017-01-07T10:50:30.137 回答
3

进行多字整数算术的一种方法是使用双双算术。让我们从一些双倍乘法代码开始

#include <math.h>
typedef struct {
  double hi;
  double lo;
} doubledouble;

static doubledouble quick_two_sum(double a, double b) {
  double s = a + b;
  double e = b - (s - a);
  return (doubledouble){s, e};
}

static doubledouble two_prod(double a, double b) {
  double p = a*b;
  double e = fma(a, b, -p);
  return (doubledouble){p, e};
}

doubledouble df64_mul(doubledouble a, doubledouble b) {
  doubledouble p = two_prod(a.hi, b.hi);
  p.lo += a.hi*b.lo;
  p.lo += a.lo*b.hi;
  return quick_two_sum(p.hi, p.lo);
}

该函数two_prod可以在两条指令中执行整数 53bx53b -> 106b。该函数df64_mul可以做整数 106bx106b -> 106b。

让我们将其与整数 128bx128b -> 128b 与整数硬件进行比较。

__int128 mul128(__int128 a, __int128 b) {
  return a*b;
}

大会为mul128

imul    rsi, rdx
mov     rax, rdi
imul    rcx, rdi
mul     rdx
add     rcx, rsi
add     rdx, rcx

df64_mul(编译时gcc -O3 -S i128.c -masm=intel -mfma -ffp-contract=off)的程序集

vmulsd      xmm4, xmm0, xmm2
vmulsd      xmm3, xmm0, xmm3
vmulsd      xmm1, xmm2, xmm1
vfmsub132sd xmm0, xmm4, xmm2
vaddsd      xmm3, xmm3, xmm0
vaddsd      xmm1, xmm3, xmm1
vaddsd      xmm0, xmm1, xmm4
vsubsd      xmm4, xmm0, xmm4
vsubsd      xmm1, xmm1, xmm4

mul128执行三个标量乘法和两个标量加法/减法,而df64_mul执行 3 个 SIMD 乘法、1 个 SIMD FMA 和 5 个 SIMD 加法/减法。我没有分析这些方法,但对我来说,使用每个 AVX 寄存器的 4-doubles (更改为和)df64_mul可以胜过这似乎不是不合理的。mul128sdpdxmmymm


很容易说问题是切换回整数域。但为什么这是必要的?您可以在浮点域中做任何事情。让我们看一些例子。float我发现使用 进行单元测试比使用更容易double

doublefloat two_prod(float a, float b) {
  float p = a*b;
  float e = fma(a, b, -p);
  return (doublefloat){p, e};
}

//3202129*4807935=15395628093615
x = two_prod(3202129,4807935)
int64_t hi = p, lo = e, s = hi+lo
//p = 1.53956280e+13, e = 1.02575000e+05  
//hi = 15395627991040, lo = 102575, s = 15395628093615

//1450779*1501672=2178594202488
y = two_prod(1450779, 1501672)
int64_t hi = p, lo = e, s = hi+lo 
//p = 2.17859424e+12, e = -4.00720000e+04
//hi = 2178594242560 lo = -40072, s = 2178594202488

所以我们最终得到不同的范围,在第二种情况下,误差 ( e) 甚至是负数,但总和仍然是正确的。我们甚至可以将两个 doublefloat 值x相加y(一旦我们知道如何进行 double-double 加法 - 请参见最后的代码)并得到15395628093615+2178594202488. 无需对结果进行标准化。

但是加法带来了双双算术的主要问题。也就是说,加法/减法很慢,例如 128b+128b -> 128b至少需要 11 个浮点加法,而对于整数,它只需要两个 (addadc)。

因此,如果一个算法重于乘法但轻于加法,那么使用 double-double 进行多字整数运算可能会获胜。


作为旁注,C 语言足够灵活,可以实现整数完全通过浮点硬件实现的实现。 int可以是 24 位(来自单个浮点),long也可以是 54 位。(来自双浮点),并且long long可能是 106 位(来自双双)。C 甚至不需要二进制补码,因此整数可以像通常使用浮点一样对负数使用带符号的幅度。


这是带有双倍乘法和加法的工作 C 代码(我还没有实现除法或其他操作,例如,sqrt但有论文显示如何做到这一点)以防有人想玩它。看看这是否可以针对整数进行优化会很有趣。

//if compiling with -mfma you must also use -ffp-contract=off
//float-float is easier to debug. If you want double-double replace
//all float words with double and fmaf with fma 
#include <stdio.h>
#include <math.h>
#include <inttypes.h>
#include <x86intrin.h>
#include <stdlib.h>

//#include <float.h>

typedef struct {
  float hi;
  float lo;
} doublefloat;

typedef union {
  float f;
  int i;
  struct {
    unsigned mantisa : 23;
    unsigned exponent: 8;
    unsigned sign: 1;
  };
} float_cast;

void print_float(float_cast a) {
  printf("%.8e, 0x%x, mantisa 0x%x, exponent 0x%x, expondent-127 %d, sign %u\n", a.f, a.i, a.mantisa, a.exponent, a.exponent-127, a.sign);
}

void print_doublefloat(doublefloat a) {
  float_cast hi = {a.hi};
  float_cast lo = {a.lo};
  printf("hi: "); print_float(hi);
  printf("lo: "); print_float(lo);
}

doublefloat quick_two_sum(float a, float b) {
  float s = a + b;
  float e = b - (s - a);
  return (doublefloat){s, e};
  // 3 add
}

doublefloat two_sum(float a, float b) {
  float s = a + b;
  float v = s - a;
  float e = (a - (s - v)) + (b - v);
  return (doublefloat){s, e};
  // 6 add 
}

doublefloat df64_add(doublefloat a, doublefloat b) {
  doublefloat s, t;
  s = two_sum(a.hi, b.hi);
  t = two_sum(a.lo, b.lo);
  s.lo += t.hi;
  s = quick_two_sum(s.hi, s.lo);
  s.lo += t.lo;
  s = quick_two_sum(s.hi, s.lo);
  return s;
  // 2*two_sum, 2 add, 2*quick_two_sum = 2*6 + 2 + 2*3 = 20 add
}

doublefloat split(float a) {
  //#define SPLITTER (1<<27) + 1
#define SPLITTER (1<<12) + 1
  float t = (SPLITTER)*a;
  float hi = t - (t - a);
  float lo = a - hi;
  return (doublefloat){hi, lo};
  // 1 mul, 3 add
}

doublefloat split_sse(float a) {
  __m128 k = _mm_set1_ps(4097.0f);
  __m128 a4 = _mm_set1_ps(a);
  __m128 t = _mm_mul_ps(k,a4);
  __m128 hi4 = _mm_sub_ps(t,_mm_sub_ps(t, a4));
  __m128 lo4 = _mm_sub_ps(a4, hi4);
  float tmp[4];
  _mm_storeu_ps(tmp, hi4);
  float hi = tmp[0];
  _mm_storeu_ps(tmp, lo4);
  float lo = tmp[0];
  return (doublefloat){hi,lo};

}

float mult_sub(float a, float b, float c) {
  doublefloat as = split(a), bs = split(b);
  //print_doublefloat(as);
  //print_doublefloat(bs);
  return ((as.hi*bs.hi - c) + as.hi*bs.lo + as.lo*bs.hi) + as.lo*bs.lo;
  // 4 mul, 4 add, 2 split = 6 mul, 10 add
}

doublefloat two_prod(float a, float b) {
  float p = a*b;
  float e = mult_sub(a, b, p);
  return (doublefloat){p, e};
  // 1 mul, one mult_sub
  // 7 mul, 10 add
}

float mult_sub2(float a, float b, float c) {
  doublefloat as = split(a);
  return ((as.hi*as.hi -c ) + 2*as.hi*as.lo) + as.lo*as.lo;
}

doublefloat two_sqr(float a) {
  float p = a*a;
  float e = mult_sub2(a, a, p);
  return (doublefloat){p, e};
}

doublefloat df64_mul(doublefloat a, doublefloat b) {
  doublefloat p = two_prod(a.hi, b.hi);
  p.lo += a.hi*b.lo;
  p.lo += a.lo*b.hi;
  return quick_two_sum(p.hi, p.lo);
  //two_prod, 2 add, 2mul, 1 quick_two_sum = 9 mul, 15 add 
  //or 1 mul, 1 fma, 2add 2mul, 1 quick_two_sum = 3 mul, 1 fma, 5 add
}

doublefloat df64_sqr(doublefloat a) {
  doublefloat p = two_sqr(a.hi);
  p.lo += 2*a.hi*a.lo;
  return quick_two_sum(p.hi, p.lo);
}

int float2int(float a) {
  int M = 0xc00000; //1100 0000 0000 0000 0000 0000
  a += M;
  float_cast x;
  x.f = a;
  return x.i - 0x4b400000;
}

doublefloat add22(doublefloat a, doublefloat b) {
  float r = a.hi + b.hi;
  float s = fabsf(a.hi) > fabsf(b.hi) ?
    (((a.hi - r) + b.hi) + b.lo ) + a.lo :
    (((b.hi - r) + a.hi) + a.lo ) + b.lo;
  return two_sum(r, s);  
  //11 add 
}

int main(void) {
  //print_float((float_cast){1.0f});
  //print_float((float_cast){-2.0f});
  //print_float((float_cast){0.0f});
  //print_float((float_cast){3.14159f});
  //print_float((float_cast){1.5f});
  //print_float((float_cast){3.0f});
  //print_float((float_cast){7.0f});
  //print_float((float_cast){15.0f});
  //print_float((float_cast){31.0f});

  //uint64_t t = 0xffffff;
  //print_float((float_cast){1.0f*t});
  //printf("%" PRId64 " %" PRIx64 "\n", t*t,t*t);

  /*
    float_cast t1;
    t1.mantisa = 0x7fffff;
    t1.exponent = 0xfe;
    t1.sign = 0;
    print_float(t1);
  */
  //doublefloat z = two_prod(1.0f*t, 1.0f*t);
  //print_doublefloat(z);
  //double z2 = (double)z.hi + (double)z.lo;
  //printf("%.16e\n", z2);
  doublefloat s = {0};
  int64_t si = 0;
  for(int i=0; i<100000; i++) {
    int ai = rand()%0x800, bi = rand()%0x800000;
    float a = ai, b = bi;
    doublefloat z = two_prod(a,b);
    int64_t zi = (int64_t)ai*bi;
    //print_doublefloat(z);
    //s = df64_add(s,z);
    s = add22(s,z);
    si += zi;
    print_doublefloat(z);
    printf("%d %d ", ai,bi);
    int64_t h = z.hi;
    int64_t l = z.lo;
    int64_t t = h+l;
    //if(t != zi) printf("%" PRId64 " %" PRId64 "\n", h, l);

    printf("%" PRId64 " %" PRId64 " %" PRId64 " %" PRId64 "\n", zi, h, l, h+l);

    h = s.hi;
    l = s.lo;
    t = h + l;
    //if(si != t) printf("%" PRId64 " %" PRId64 "\n", h, l);

    if(si > (1LL<<48)) {
      printf("overflow after %d iterations\n", i); break;
    }
  }

  print_doublefloat(s);
  printf("%" PRId64 "\n", si);
  int64_t x = s.hi;
  int64_t y = s.lo;
  int64_t z = x+y;
  //int hi = float2int(s.hi);
  printf("%" PRId64 " %" PRId64 " %" PRId64 "\n", z,x,y);
}
于 2017-01-11T12:52:51.370 回答
2

好吧,您当然可以对整数进行 FP-lane 操作。而且它们总是准确的:虽然有些 SSE 指令不能保证正确的 IEEE-754 精度和舍入,但它们无一例外都是没有整数范围的指令,所以无论如何都不是你要查看的指令。底线:加法/减法/乘法在整数域中始终是精确的,即使您是在压缩浮点数上执行它们。

至于四精度浮点数(> 52 位尾数),不,不支持这些,并且在可预见的将来可能不会。只是没有太多要求他们。它们出现在一些 SPARC 时代的工作站架构中,但老实说,它们只是开发人员对如何编写数值稳定算法的不完全理解的绷带,并且随着时间的推移它们逐渐淡出。

宽整数运算结果非常不适合 SSE。我最近在实现一个大整数库时真的尝试过利用它,老实说,它对我没有好处。x86 是为多字算术而设计的;您可以在诸如 ADC(它产生并消耗一个进位位)和 IDIV(它允许除数的宽度是被除数的两倍,只要商不比被除数宽)等操作中看到它,这是一个约束对除多字除法之外的任何东西都没用)。但是多字算术本质上是顺序的,而 SSE 本质上是并行的。如果你足够幸运,你的号码刚好够用适合 FP 尾数的位,恭喜。但是,如果您通常有大整数,那么 SSE 可能不会成为您的朋友。

于 2017-01-03T22:51:12.210 回答