14

考虑到您要计算 64 位和 128 位无符号数相乘结果的低 128 位,并且您可用的最大乘法是类似 C 的 64 位乘法,它需要两个 64 位无符号输入并返回结果的低 64 位。

需要多少次乘法?

当然,您可以使用 8 次:将所有输入分解为 32 位块并使用 64 位乘法来执行 4 * 2 = 8 所需的全宽 32*32->64 乘法,但可以做得更好?

当然,该算法应该只在乘法之上执行“合理”数量的加法或其他基本算术(我对将乘法重新发明为加法循环并因此声称“零”乘法的解决方案不感兴趣)。

4

2 回答 2

16

四,但它开始变得有点棘手。

ab是要相乘的数,a 0a 1分别是a的低 32 位和高 32 位,b 0b 1b 2b 3是b的 32 位组,从分别从低到高。

所需结果是 ( a 0 + a 1 •2 32 ) • ( b 0 + b 1 •2 32 + b 2 •2 64 + b 3 •2 96 ) 模 2 128的余数。

我们可以将其重写为 ( a 0 + a 1 •2 32 ) • ( b 0 + b 1 •2 32 ) + ( a 0 + a 1 •2 32 ) • ( b 2 •2 64 + b 3 •2 96 ) 模 2 128

后一项模 2 128的余数可以计算为单个 64 位乘 64 位乘法(其结果隐式乘以 2 64)。

然后可以使用仔细实现的Karatsuba步骤通过三个乘法计算前一项。简单版本将涉及 33 位乘 33 位到 66 位产品,这是不可用的,但有一个更棘手的版本可以避免它:

z0 = a0 * b0
z2 = a1 * b1
z1 = abs(a0 - a1) * abs(b0 - b1) * sgn(a0 - a1) * sgn(b1 - b0) + z0 + z2

最后一行只包含一个乘法;其他两个伪乘法只是条件否定。绝对差分和条件否定在纯 C 中实现很烦人,但可以做到。

于 2018-08-17T20:16:04.997 回答
4

当然,没有 Karatsuba,5 倍增。

Karatsuba 很棒,但现在 64 x 64 乘法可以在 3 个时钟内完成,并且每个时钟都可以安排一个新的乘法。因此,处理符号和其他符号的开销可能比节省一个乘法的开销要大得多。

对于简单的 64 x 64 乘法需求:

    r0 =       a0*b0
    r1 =    a0*b1
    r2 =    a1*b0
    r3 = a1*b1

  where need to add r0 = r0 + (r1 << 32) + (r2 << 32)
            and add r3 = r3 + (r1 >> 32) + (r2 >> 32) + carry

  where the carry is the carry from the additions to r0, and result is r3:r0.

typedef struct { uint64_t w0, w1 ; } uint64x2_t ;

uint64x2_t
mulu64x2(uint64_t x, uint64_t m)
{
  uint64x2_t r ;
  uint64_t r1, r2, rx, ry ;
  uint32_t x1, x0 ;
  uint32_t m1, m0 ;

  x1    = (uint32_t)(x >> 32) ;
  x0    = (uint32_t)x ;
  m1    = (uint32_t)(m >> 32) ;
  m0    = (uint32_t)m ;

  r1    = (uint64_t)x1 * m0 ;
  r2    = (uint64_t)x0 * m1 ;
  r.w0  = (uint64_t)x0 * m0 ;
  r.w1  = (uint64_t)x1 * m1 ;

  rx    = (uint32_t)r1 ;
  rx    = rx + (uint32_t)r2 ;    // add the ls halves, collecting carry
  ry    = r.w0 >> 32 ;           // pick up ms of r0
  r.w0 += (rx << 32) ;           // complete r0
  rx   += ry ;                   // complete addition, rx >> 32 == carry !

  r.w1 += (r1 >> 32) + (r2 >> 32) + (rx >> 32) ;

  return r ;
}

对于 Karatsuba,建议:

z1 = abs(a0 - a1) * abs(b0 - b1) * sgn(a0 - a1) * sgn(b1 - b0) + z0 + z2

比看起来更棘手...首先,如果z1是 64 位,则需要以某种方式收集此加法可以生成的进位...这因签名问题而变得复杂。

    z0 =       a0*b0
    z1 =    ax*bx        -- ax = (a1 - a0), bx = (b0 - b1)
    z2 = a1*b1

  where need to add r0 = z0 + (z1 << 32) + (z0 << 32) + (z2 << 32)
            and add r1 = z2 + (z1 >> 32) + (z0 >> 32) + (z2 >> 32) + carry

  where the carry is the carry from the additions to create r0, and result is r1:r0.

  where must take into account the signed-ness of ax, bx and z1. 

uint64x2_t
mulu64x2_karatsuba(uint64_t a, uint64_t b)
{
  uint64_t a0, a1, b0, b1 ;
  uint64_t ax, bx, zx, zy ;
  uint     as, bs, xs ;
  uint64_t z0, z2 ;
  uint64x2_t r ;

  a0 = (uint32_t)a ; a1 = a >> 32 ;
  b0 = (uint32_t)b ; b1 = b >> 32 ;

  z0 = a0 * b0 ;
  z2 = a1 * b1 ;

  ax = (uint64_t)(a1 - a0) ;
  bx = (uint64_t)(b0 - b1) ;

  as = (uint)(ax > a1) ;                // sign of magic middle, a
  bs = (uint)(bx > b0) ;                // sign of magic middle, b
  xs = (uint)(as ^ bs) ;                // sign of magic middle, x = a * b

  ax = (uint64_t)((ax ^ -(uint64_t)as) + as) ;  // abs magic middle a
  bx = (uint64_t)((bx ^ -(uint64_t)bs) + bs) ;  // abs magic middle b

  zx = (uint64_t)(((ax * bx) ^ -(uint64_t)xs) + xs) ;
  xs = xs & (uint)(zx != 0) ;           // discard sign if z1 == 0 !

  zy = (uint32_t)zx ;                   // start ls half of z1
  zy = zy + (uint32_t)z0 + (uint32_t)z2 ;

  r.w0 = z0 + (zy << 32) ;              // complete ls word of result.
  zy   = zy + (z0 >> 32) ;              // complete carry

  zx   = (zx >> 32) - ((uint64_t)xs << 32) ;   // start ms half of z1
  r.w1 = z2 + zx + (z0 >> 32) + (z2 >> 32) + (zy >> 32) ;

  return r ;
}

我做了一些非常简单的计时(使用times(),在 Ryzen 7 1800X 上运行):

  • 使用 gcc __int128........ ~780 'units'
  • 使用 mulu64x2() ..................... ~895
  • 使用 mulu64x2_karatsuba()... ~1,095

...所以,是的,您可以通过使用 Karatsuba 来保存乘法,但是否值得这样做取决于。

于 2020-02-13T20:04:32.053 回答