1

我试图了解 Karatsuba 乘法算法。我编写了以下代码:

def karatsuba_multiply(x, y):
    # split x and y
    len_x = len(str(x))
    len_y = len(str(y))
    if len_x == 1 or len_y == 1:
        return x*y

    n = max(len_x, len_y)
    n_half = 10**(n // 2)
    a = x // n_half
    b = x % n_half
    c = y // n_half
    d = y % n_half

    ac = karatsuba_multiply(a, c)
    bd = karatsuba_multiply(b, d)
    ad_plus_bc = karatsuba_multiply((a+b), (c+d)) - ac - bd

    return (10**n * ac) + (n_half * ad_plus_bc) + bd

此测试用例不起作用:

print(karatsuba_multiply(1234, 5678)) ## returns 11686652, should be 7006652‬

但是,如果我使用此答案中的以下代码,则测试用例会产生正确的答案:

def karat(x,y):
    if len(str(x)) == 1 or len(str(y)) == 1:
        return x*y
    else:
        m = max(len(str(x)),len(str(y)))
        m2 = m // 2

        a = x // 10**(m2)
        b = x % 10**(m2)
        c = y // 10**(m2)
        d = y % 10**(m2)

        z0 = karat(b,d)
        z1 = karat((a+b),(c+d))
        z2 = karat(a,c)

        return (z2 * 10**(2*m2)) + ((z1 - z2 - z0) * 10**(m2)) + (z0)

这两个函数看起来都在做同样的事情。为什么我的不工作?

4

1 回答 1

1

似乎在kerat_multiply实施过程中,您不能使用正确的公式进行最后一次返回。

在原始kerat实现中,值m2 = m // 2在最后一次返回中乘以 2 (z2 * 10**(2*m2)) + ((z1 - z2 - z0) * 10**(m2)) + (z0) (2*m2)

所以你我认为你需要在下面添加一个新变量,n2 == n // 2以便你可以在最后一次返回中将它乘以 2,或者使用原始实现。

希望它有帮助:)

编辑:这是由2 * n // 2不同于2 * (n // 2)

n = max(len_x, len_y)
n_half = 10**(n // 2)
n2 = n // 2
a = x // n_half
b = x % n_half
c = y // n_half
d = y % n_half

ac = karatsuba_multiply(a, c)
bd = karatsuba_multiply(b, d)
ad_plus_bc = karatsuba_multiply((a + b), (c + d)) - ac - bd

return (10**(2 * n2) * ac) + (n_half * (ad_plus_bc)) + bd
于 2020-06-18T12:35:14.893 回答