0

我正在尝试实现递归 Karatsuba 算法。我成功编写了一个乘法递归算法,但它计算了 ad 和 bc。但是,对于这个程序,我尝试记下每个中间值,即 ac、bd、total、sum。许多值并没有作为预期值出现。我无法弄清楚我的代码在哪里搞砸了。我仍然是一名业余程序员,我已经花了几个小时尝试调试,但现在我别无选择,只能在这里发布我的大代码:

#include <iostream>
using namespace std;

int approxLog(int n, int b) {
    return !(n/b < 1) ? 1+approxLog(n/b, b) : 1;
}

int power(int b, int e) {
    return (e < 1) ? 1 : b*power(b, e-1);
}

int odd1(int n) {
    if(n%2 != 0)
        return n-1;
    else
        return n;
}

int odd2(int n) {
    if(n%2 == 0)
        return n/2;
    else
        return n/2 + 1;
}

void num_split (int num, int d, int *a, int *b) {
    int  i = 1, tmp = 0, j = 1, k = 0;
    while (i <= d/2) {
        tmp += (num%10)*power(10, i-1);
        num /= 10;
        i++;
    }
    *b = tmp;
    tmp = 0;
    while (i <= d) {
        tmp += (num%10)*power(10, j-1);
        num /= 10;
        i++;
        j++;
    }
    *a = tmp;
    tmp = 0;
}

long long int multiply(int x, int y, int n) {
    int a = 0, b = 0, c = 0, d = 0;
    int ac = 0, bd = 0, total = 0, sum = 0;
    int *ptr_a = &a;
    int *ptr_b = &b;
    int *ptr_c = &c;
    int *ptr_d = &d;
    num_split(x, n, ptr_a, ptr_b);
    num_split(y, n, ptr_c, ptr_d);

    if((x < 10) || (y < 10)) {
        return x*y;
    }
    else {
        ac = multiply(a, c, odd2(n));
        bd = multiply(b, d, n/2);
        total = multiply((a+b), (c+d), odd2(n));
        // cout << total <<  endl;
        sum = total - ac - bd;
        return power(10, odd1(n))*ac + power(10, n/2))*sum + bd;
    }
}

int main() {
    int x = 1234, y = 1234;
    int n = approxLog(x, 10);
    long long int product = multiply(x, y, n);
    cout << product << endl;

    return 0;
}
4

1 回答 1

0

问题是在每次递归中,你应该取大约。x和中较大者的对数y。因此,您的代码中的以下更改将解决问题(也请注意注释掉的部分!):

long long int multiply(int x, int y)//, int n)
{
    int a = 0, b = 0, c = 0, d = 0;
    int ac = 0, bd = 0, total = 0, sum = 0;
    int *ptr_a = &a;
    int *ptr_b = &b;
    int *ptr_c = &c;
    int *ptr_d = &d;
    int n = (x>y)?approxLog(x, 10):approxLog(y, 10);
    num_split(x, n, ptr_a, ptr_b);
    num_split(y, n, ptr_c, ptr_d);

    if((x < 10) || (y < 10)) {
        return x*y;
    }
    else {
        ac = multiply(a, c);//, odd2(n));
        bd = multiply(b, d);//, n/2);
        total = multiply((a+b), (c+d));//, odd2(n));
        cout << total <<  endl;
        sum = total - ac - bd;
        return power(10, odd1(n))*ac + power(10, (n/2))*sum + bd;
    }
}

int main() {
    int x = 134546, y = 1234;
    //int n = approxLog(x, 10);
    long long int product = multiply(x, y);//, n);
    cout<<"Karatsuba: "<<product<<endl;
    cout<<"Normal:    "<<x*y<<endl;

    return 0;
}
于 2018-05-23T16:40:54.940 回答