0

我正在实施Strassen 的矩阵乘法算法作为分配的一部分。我已经正确编码,但我不知道为什么它会给出分段错误。我将 strassen() 称为strassen(0,n,0,n); 主要。n 是用户给定的数字,它是 2 的幂,它是矩阵(2D 数组)的最大尺寸。对于 n = 4,它没有给出段错误,但对于 n = 8、16、32,它给出了段错误。代码如下所示。

    void strassen(int p, int q, int r, int s)
    {
        int p1,p2,p3,p4,p5,p6,p7;   
        if(((q-p) == 2)&&((s-r) == 2))
        {
            p1 = ((a[p][r] + a[p+1][r+1])*(b[p][r] + b[p+1][r+1]));
            p2 = ((a[p+1][r] + a[p+1][r+1])*b[p][r]);
            p3 = (a[p][r]*(b[p][r+1] - b[p+1][r+1]));
            p4 = (a[p+1][r+1]*(b[p+1][r] - b[p][r]));
            p5 = ((a[p][r] + a[p][r+1])*b[p+1][r+1]);
            p6 = ((a[p+1][r] - a[p][r])*(b[p][r] +b[p][r+1]));
            p7 = ((a[p][r+1] - a[p+1][r+1])*(b[p+1][r] + b[p+1][r+1]));
            c[p][r] = p1 + p4 - p5 + p7;
            c[p][r+1] = p3 + p5;
            c[p+1][r] = p2 + p4;
            c[p+1][r+1] = p1 + p3 - p2 + p6;
        }
        else
        {
            strassen(p, q/2, r, s/2);
            strassen(p, q/2, s/2, s);
            strassen(q/2, q, r, s/2);
            strassen(q/2, q, s/2, s);
        }
    }
4

2 回答 2

2

else 块中的某些条件是无限递归的(至少第二个和第四个,没有检查另一个)。这可以很容易地用笔和纸证明:例如
strassen(p, q/2, s/2, s),对于 `0,8,0,8 将在每次迭代中产生:

1) 0, 4, 4, 8

2) 0, 2, 4, 8

3) 0, 1, 4, 8

4) 0, 0, 4, 8

5) 0, 0, 4, 8

...

and since none of those results pass your

if(((q-p) == 2)&&((s-r) == 2))

test, the function will run (and I suspect branch, as the 4th function has the same problem...) until the end of the stack is hit, causing a Segmentation Fault.

Anyway, if what you are trying to do in the else block is to recursively bisect the matrix, a better attempt would be something like:

strassen(p, (q+p)/2, r, (r+s)/2);                                               
strassen(p, (q+p)/2, (r+s)/2, s);                                               
strassen((q+p)/2,q, (r+s)/2, s);                                                
strassen((q+p)/2,q, r, (r+s)/2);        

(keep in mind that I didn't check this code, though)

于 2013-03-16T16:39:37.627 回答
0
void strassen(int p, int q, int r, int s)
{
    int p1,p2,p3,p4,p5,p6,p7;   
    if(q-p == 2 && s-r == 2)
    {
        p1 = (a[p][r] + a[p+1][r+1])   * (b[p][r] + b[p+1][r+1]);
        p2 = (a[p+1][r] + a[p+1][r+1]) * b[p][r];
        p3 = a[p][r]                   * (b[p][r+1] - b[p+1][r+1]);
        p4 = a[p+1][r+1]               * (b[p+1][r] - b[p][r]);
        p5 = (a[p][r] + a[p][r+1])     * b[p+1][r+1];
        p6 = (a[p+1][r] - a[p][r])     * (b[p][r] +b[p][r+1] );
        p7 = (a[p][r+1] - a[p+1][r+1]) * (b[p+1][r] + b[p+1][r+1]);
        c[p][r] = p1 + p4 - p5 + p7;
        c[p][r+1] = p3 + p5;
        c[p+1][r] = p2 + p4;
        c[p+1][r+1] = p1 + p3 - p2 + p6;
    }
    else
    {
        if (q/2-p >= 2 && s/2-r >= 2) strassen(p, q/2, r, s/2);
        if (q/2-p >= 2 && s-s/2 >= 2) strassen(p, q/2, s/2, s);
        if (q-q/2 >= 2 && s/2-r >= 2) strassen(q/2, q, r, s/2);
        if (q-q/2 >= 2 && s-s/2 >= 2) strassen(q/2, q, s/2, s);
    }
}

But an easier recursion stopper would be at the beginning of the function, like:

{
    int p1,p2,p3,p4,p5,p6,p7;   
    if(q-p < 2 || s-r < 2) return;
    if(q-p == 2 && s-r == 2)
    { ...
于 2013-03-16T16:48:45.607 回答