考虑到您要计算 64 位和 128 位无符号数相乘结果的低 128 位,并且您可用的最大乘法是类似 C 的 64 位乘法,它需要两个 64 位无符号输入并返回结果的低 64 位。
需要多少次乘法?
当然,您可以使用 8 次:将所有输入分解为 32 位块并使用 64 位乘法来执行 4 * 2 = 8 所需的全宽 32*32->64 乘法,但可以做得更好?
当然,该算法应该只在乘法之上执行“合理”数量的加法或其他基本算术(我对将乘法重新发明为加法循环并因此声称“零”乘法的解决方案不感兴趣)。
考虑到您要计算 64 位和 128 位无符号数相乘结果的低 128 位,并且您可用的最大乘法是类似 C 的 64 位乘法,它需要两个 64 位无符号输入并返回结果的低 64 位。
需要多少次乘法?
当然,您可以使用 8 次:将所有输入分解为 32 位块并使用 64 位乘法来执行 4 * 2 = 8 所需的全宽 32*32->64 乘法,但可以做得更好?
当然,该算法应该只在乘法之上执行“合理”数量的加法或其他基本算术(我对将乘法重新发明为加法循环并因此声称“零”乘法的解决方案不感兴趣)。
四,但它开始变得有点棘手。
设a和b是要相乘的数,a 0和a 1分别是a的低 32 位和高 32 位,b 0、b 1、b 2、b 3是b的 32 位组,从分别从低到高。
所需结果是 ( a 0 + a 1 •2 32 ) • ( b 0 + b 1 •2 32 + b 2 •2 64 + b 3 •2 96 ) 模 2 128的余数。
我们可以将其重写为 ( a 0 + a 1 •2 32 ) • ( b 0 + b 1 •2 32 ) + ( a 0 + a 1 •2 32 ) • ( b 2 •2 64 + b 3 •2 96 ) 模 2 128。
后一项模 2 128的余数可以计算为单个 64 位乘 64 位乘法(其结果隐式乘以 2 64)。
然后可以使用仔细实现的Karatsuba步骤通过三个乘法计算前一项。简单版本将涉及 33 位乘 33 位到 66 位产品,这是不可用的,但有一个更棘手的版本可以避免它:
z0 = a0 * b0
z2 = a1 * b1
z1 = abs(a0 - a1) * abs(b0 - b1) * sgn(a0 - a1) * sgn(b1 - b0) + z0 + z2
最后一行只包含一个乘法;其他两个伪乘法只是条件否定。绝对差分和条件否定在纯 C 中实现很烦人,但可以做到。
当然,没有 Karatsuba,5 倍增。
Karatsuba 很棒,但现在 64 x 64 乘法可以在 3 个时钟内完成,并且每个时钟都可以安排一个新的乘法。因此,处理符号和其他符号的开销可能比节省一个乘法的开销要大得多。
对于简单的 64 x 64 乘法需求:
r0 = a0*b0
r1 = a0*b1
r2 = a1*b0
r3 = a1*b1
where need to add r0 = r0 + (r1 << 32) + (r2 << 32)
and add r3 = r3 + (r1 >> 32) + (r2 >> 32) + carry
where the carry is the carry from the additions to r0, and result is r3:r0.
typedef struct { uint64_t w0, w1 ; } uint64x2_t ;
uint64x2_t
mulu64x2(uint64_t x, uint64_t m)
{
uint64x2_t r ;
uint64_t r1, r2, rx, ry ;
uint32_t x1, x0 ;
uint32_t m1, m0 ;
x1 = (uint32_t)(x >> 32) ;
x0 = (uint32_t)x ;
m1 = (uint32_t)(m >> 32) ;
m0 = (uint32_t)m ;
r1 = (uint64_t)x1 * m0 ;
r2 = (uint64_t)x0 * m1 ;
r.w0 = (uint64_t)x0 * m0 ;
r.w1 = (uint64_t)x1 * m1 ;
rx = (uint32_t)r1 ;
rx = rx + (uint32_t)r2 ; // add the ls halves, collecting carry
ry = r.w0 >> 32 ; // pick up ms of r0
r.w0 += (rx << 32) ; // complete r0
rx += ry ; // complete addition, rx >> 32 == carry !
r.w1 += (r1 >> 32) + (r2 >> 32) + (rx >> 32) ;
return r ;
}
对于 Karatsuba,建议:
z1 = abs(a0 - a1) * abs(b0 - b1) * sgn(a0 - a1) * sgn(b1 - b0) + z0 + z2
比看起来更棘手...首先,如果z1
是 64 位,则需要以某种方式收集此加法可以生成的进位...这因签名问题而变得复杂。
z0 = a0*b0
z1 = ax*bx -- ax = (a1 - a0), bx = (b0 - b1)
z2 = a1*b1
where need to add r0 = z0 + (z1 << 32) + (z0 << 32) + (z2 << 32)
and add r1 = z2 + (z1 >> 32) + (z0 >> 32) + (z2 >> 32) + carry
where the carry is the carry from the additions to create r0, and result is r1:r0.
where must take into account the signed-ness of ax, bx and z1.
uint64x2_t
mulu64x2_karatsuba(uint64_t a, uint64_t b)
{
uint64_t a0, a1, b0, b1 ;
uint64_t ax, bx, zx, zy ;
uint as, bs, xs ;
uint64_t z0, z2 ;
uint64x2_t r ;
a0 = (uint32_t)a ; a1 = a >> 32 ;
b0 = (uint32_t)b ; b1 = b >> 32 ;
z0 = a0 * b0 ;
z2 = a1 * b1 ;
ax = (uint64_t)(a1 - a0) ;
bx = (uint64_t)(b0 - b1) ;
as = (uint)(ax > a1) ; // sign of magic middle, a
bs = (uint)(bx > b0) ; // sign of magic middle, b
xs = (uint)(as ^ bs) ; // sign of magic middle, x = a * b
ax = (uint64_t)((ax ^ -(uint64_t)as) + as) ; // abs magic middle a
bx = (uint64_t)((bx ^ -(uint64_t)bs) + bs) ; // abs magic middle b
zx = (uint64_t)(((ax * bx) ^ -(uint64_t)xs) + xs) ;
xs = xs & (uint)(zx != 0) ; // discard sign if z1 == 0 !
zy = (uint32_t)zx ; // start ls half of z1
zy = zy + (uint32_t)z0 + (uint32_t)z2 ;
r.w0 = z0 + (zy << 32) ; // complete ls word of result.
zy = zy + (z0 >> 32) ; // complete carry
zx = (zx >> 32) - ((uint64_t)xs << 32) ; // start ms half of z1
r.w1 = z2 + zx + (z0 >> 32) + (z2 >> 32) + (zy >> 32) ;
return r ;
}
我做了一些非常简单的计时(使用times()
,在 Ryzen 7 1800X 上运行):
...所以,是的,您可以通过使用 Karatsuba 来保存乘法,但是否值得这样做取决于。