4

我从某个地方复制了 Strassen 算法的实现并运行了它。以下是输出:

n = 256
classical took 360ms
strassen 1 took 33609ms
strassen2 took 1172ms
classical took 437ms
strassen 1 took 32891ms
strassen2 took 1156ms
classical took 266ms
strassen 1 took 27234ms
strassen2 took 734ms

其中 strassen1 是动态分配内存的实现,strassen2 是使用缓存的实现,classical 是传统的矩阵乘法。这似乎意味着古老而简单的经典算法才是最好的。这是真的,还是我在某个地方弄错了?以下是 Java 代码。

import java.util.Random;

class TestIntMatrixMultiplication {

    /** Nanoseconds in one millisecond; times are measured in ns and reported in ms. */
    private static final long NANOS_PER_MILLI = 1000000L;

    /**
     * Prints or benchmarks integer matrix multiplication.
     *
     * <p>Arguments: {@code [n [seed]]}, both defaulting to 256. Entries of A and B are drawn
     * from [0, 100). For n < 64 the operands and the classical and cached-Strassen products
     * are printed. Otherwise each algorithm is timed three times; the dynamic Strassen variant
     * is timed only when n <= 256.
     */
    public static void main (String...args) throws Exception {
        final int n = args.length > 0 ? Integer.parseInt(args[0]) : 256;
        final int seed = args.length > 1 ? Integer.parseInt(args[1]) : 256;
        final Random random = new Random(seed);

        int[][] a, b, c;

        a = new int[n][n];
        b = new int[n][n];
        c = new int[n][n];

        for(int i=0; i<n; i++) {
            for(int j=0; j<n; j++) {
                a[i][j] = random.nextInt(100);
                b[i][j] = random.nextInt(100);
            }
        }

        System.out.println("n = " + n);

        if (a.length < 64) {
            // Small sizes: print the matrices so results can be compared by eye.
            System.out.println("A");
            dumpMatrix(a);
            System.out.println("B");
            dumpMatrix(b);
            System.out.println("classic");
            Classical.mult(c, a, b);
            dumpMatrix(c);
            System.out.println("strassen");
            strassen2.mult(c, a, b);
            dumpMatrix(c);

            return;
        }

        for (int i = 0; i <3; ++i) {
            timeMultiplies1(a, b, c);
            if (n <= 256)
                timeMultiplies2( a, b, c);
            timeMultiplies3( a, b, c);
        }
    }

    /** Times {@code Classical.mult} and prints the elapsed milliseconds. */
    static void timeMultiplies1 (int[][] a, int[][] b, int[][] c) {
        final long start = System.nanoTime();
        Classical.mult(c, a, b);
        System.out.println("classical took " + elapsedMillisSince(start) + "ms");
    }

    /** Times {@code strassen1.mult} and prints the elapsed milliseconds. */
    static void timeMultiplies2(int[][] a, int[][] b, int[][] c) {
        final long start = System.nanoTime();
        strassen1.mult(c, a, b);
        System.out.println("strassen 1 took " + elapsedMillisSince(start) + "ms");
    }

    /** Times {@code strassen2.mult} and prints the elapsed milliseconds. */
    static void timeMultiplies3 (int[][] a, int[][] b, int[][] c) {
        final long start = System.nanoTime();
        strassen2.mult(c, a, b);
        System.out.println("strassen2 took " + elapsedMillisSince(start) + "ms");
    }

    /**
     * Returns whole milliseconds elapsed since {@code startNanos}, a value read earlier from
     * {@link System#nanoTime()}. nanoTime is monotonic and intended for elapsed-time
     * measurement, unlike currentTimeMillis, which tracks the wall clock and may jump.
     */
    private static long elapsedMillisSince(long startNanos) {
        return (System.nanoTime() - startNanos) / NANOS_PER_MILLI;
    }

    /** Prints each row of {@code m} as a tab-separated list wrapped in brackets. */
    static void dumpMatrix (int[][] m) {
        for (int[] row : m) {
            System.out.print("[\t");
            for (int val : row) {
                System.out.print(val);
                System.out.print('\t');
            }
            System.out.println(']');
        }
    }
}

class strassen1{

    public String getName () {
        return "Strassen(dynamic)";
    }

    public static int[][] mult (int[][] c, int[][] a, int[][] b) {
        return strassenMatrixMultiplication(a, b);
    }

    public static int [][] strassenMatrixMultiplication(int [][] A, int [][] B) {
        int n = A.length;

        int [][] result = new int[n][n];

        if(n == 1) {
            result[0][0] = A[0][0] * B[0][0];
        } else {
            int [][] A11 = new int[n/2][n/2];
            int [][] A12 = new int[n/2][n/2];
            int [][] A21 = new int[n/2][n/2];
            int [][] A22 = new int[n/2][n/2];

            int [][] B11 = new int[n/2][n/2];
            int [][] B12 = new int[n/2][n/2];
            int [][] B21 = new int[n/2][n/2];
            int [][] B22 = new int[n/2][n/2];

            divideArray(A, A11, 0 , 0);
            divideArray(A, A12, 0 , n/2);
            divideArray(A, A21, n/2, 0);
            divideArray(A, A22, n/2, n/2);

            divideArray(B, B11, 0 , 0);
            divideArray(B, B12, 0 , n/2);
            divideArray(B, B21, n/2, 0);
            divideArray(B, B22, n/2, n/2);

            int [][] P1 = strassenMatrixMultiplication(addMatrices(A11, A22), addMatrices(B11, B22));
            int [][] P2 = strassenMatrixMultiplication(addMatrices(A21, A22), B11);
            int [][] P3 = strassenMatrixMultiplication(A11, subtractMatrices(B12, B22));
            int [][] P4 = strassenMatrixMultiplication(A22, subtractMatrices(B21, B11));
            int [][] P5 = strassenMatrixMultiplication(addMatrices(A11, A12), B22);
            int [][] P6 = strassenMatrixMultiplication(subtractMatrices(A21, A11), addMatrices(B11, B12));
            int [][] P7 = strassenMatrixMultiplication(subtractMatrices(A12, A22), addMatrices(B21, B22));

            int [][] C11 = addMatrices(subtractMatrices(addMatrices(P1, P4), P5), P7);
            int [][] C12 = addMatrices(P3, P5);
            int [][] C21 = addMatrices(P2, P4);
            int [][] C22 = addMatrices(subtractMatrices(addMatrices(P1, P3), P2), P6);

            copySubArray(C11, result, 0 , 0);
            copySubArray(C12, result, 0 , n/2);
            copySubArray(C21, result, n/2, 0);
            copySubArray(C22, result, n/2, n/2);
        }

        return result;
    }

    public static int [][] addMatrices(int [][] A, int [][] B) {
        int n = A.length;

        int [][] result = new int[n][n];

        for(int i=0; i<n; i++)
        for(int j=0; j<n; j++)
        result[i][j] = A[i][j] + B[i][j];

        return result;
    }

    public static int [][] subtractMatrices(int [][] A, int [][] B) {
        int n = A.length;

        int [][] result = new int[n][n];

        for(int i=0; i<n; i++)
            for(int j=0; j<n; j++)
                result[i][j] = A[i][j] - B[i][j];

        return result;
    }

    public static void divideArray(int[][] parent, int[][] child, int iB, int jB) {
        for(int i1 = 0, i2=iB; i1<child.length; i1++, i2++)
            for(int j1 = 0, j2=jB; j1<child.length; j1++, j2++)
                child[i1][j1] = parent[i2][j2];
    }

    public static void copySubArray(int[][] child, int[][] parent, int iB, int jB) {
        for(int i1 = 0, i2=iB; i1<child.length; i1++, i2++)
            for(int j1 = 0, j2=jB; j1<child.length; j1++, j2++)
                parent[i2][j2] = child[i1][j1];
    }
}
class strassen2{

    /** Returns the human-readable name of this variant, used in reports. */
    public String getName () {
        return "Strassen(cached)";
    }

    // Scratch buffers shared by all recursion levels and all calls. Each recursion level
    // works in its own column band (selected by 'offs'), so nested calls do not overwrite
    // blocks an outer level still needs. A top-level call of size n uses rows [0, n/2) and
    // fewer than n columns. The buffers are static and mutable, so concurrent calls are unsafe.
    static int [][] p1;
    static int [][] p2;
    static int [][] p3;
    static int [][] p4;
    static int [][] p5;
    static int [][] p6;
    static int [][] p7;
    static int [][] t0;
    static int [][] t1;

    /**
     * Multiplies {@code a} by {@code b} with Strassen's algorithm and stores the product in
     * {@code c}, reusing the cached scratch buffers when they are large enough.
     *
     * <p>The size n is taken from {@code c.length}; {@code a} and {@code b} are assumed to be at
     * least n x n. Recursion levels with odd size (including 1) are multiplied directly, so
     * any n >= 0 is supported.
     *
     * @return {@code c}
     */
    public static int[][] mult (int[][] c, int[][] a, int[][] b) {
        final int n = c.length;
        if (n == 0) {
            // Nothing to compute; allocating [0][n-1] here would throw NegativeArraySizeException.
            return c;
        }

        final int rows = n / 2;
        final int cols = n - 1;
        // Reallocate only when the cached buffers are too small. The old test compared
        // p1.length (the previous n/2) against n, so it reallocated even for repeated sizes.
        if (p1 == null || p1.length < rows || (rows > 0 && p1[0].length < cols)) {
            p1 = new int[rows][cols];
            p2 = new int[rows][cols];
            p3 = new int[rows][cols];
            p4 = new int[rows][cols];
            p5 = new int[rows][cols];
            p6 = new int[rows][cols];
            p7 = new int[rows][cols];
            t0 = new int[rows][cols];
            t1 = new int[rows][cols];
        }

        mult(c, a, b, 0, 0, n, 0);

        return c;
    }

    /**
     * Recursive worker. Multiplies the n x n blocks of {@code a} and {@code b} whose top-left
     * corner is (i0, j0) and writes the product into the same block of {@code c}.
     *
     * @param offs first scratch-buffer column reserved for this level's T and P blocks
     */
    public static void mult (int[][] c, int[][] a, int[][] b, int i0, int j0, int n, int offs) {
        if (n == 0) {
            return; // an empty block would otherwise recurse forever (n/2 == 0)
        }
        if ((n & 1) != 0) {
            // Odd size (including the 1x1 base case): quadrants cannot be split evenly.
            multiplyBlock(c, a, b, i0, j0, n);
            return;
        }

        final int nBy2 = n/2;

        final int i1 = i0 + nBy2;
        final int j1 = j0 + nBy2;

        // offset applied to 'p' j index so recursive calls don't overwrite data
        final int jp0 = offs;

        // P1 <- (A11 + A22)(B11 + B22)
        //  T0 <- (A11 + A22), T1 <- (B11 + B22), P1 <- T0*T1
        for (int i = 0; i < nBy2; ++i) {
            for (int j = 0; j < nBy2; ++j) {
                t0[i + i0][j + jp0] = a[i + i0][j + j0] + a[i + i1][j + j1];
                t1[i + i0][j + jp0] = b[i + i0][j + j0] + b[i + i1][j + j1];
            }
        }

        mult(p1, t0, t1, i0, jp0, nBy2, offs + nBy2);

        // P2 <- (A21 + A22)B11
        //  T0 <- (A21 + A22), T1 <- B11, P2 <- T0*T1
        for (int i = 0; i < nBy2; ++i) {
            for (int j = 0; j < nBy2; ++j) {
                t0[i + i0][j + jp0] = a[i + i1][j + j0] + a[i + i1][j + j1];
                t1[i + i0][j + jp0] = b[i + i0][j + j0];
            }
        }

        mult(p2, t0, t1, i0, jp0, nBy2, offs + nBy2);

        // P3 <- A11(B12 - B22)
        //  T0 <- A11, T1 <- (B12 - B22), P3 <- T0*T1
        for (int i = 0; i < nBy2; ++i) {
            for (int j = 0; j < nBy2; ++j) {
                t0[i + i0][j + jp0] = a[i + i0][j + j0];
                t1[i + i0][j + jp0] = b[i + i0][j + j1] - b[i + i1][j + j1];
            }
        }

        mult(p3, t0, t1, i0, jp0, nBy2, offs + nBy2);

        // P4 <- A22(B21 - B11)
        //  T0 <- A22, T1 <- (B21 - B11), P4 <- T0*T1
        for (int i = 0; i < nBy2; ++i) {
            for (int j = 0; j < nBy2; ++j) {
                t0[i + i0][j + jp0] = a[i + i1][j + j1];
                t1[i + i0][j + jp0] = b[i + i1][j + j0] - b[i + i0][j + j0];
            }
        }

        mult(p4, t0, t1, i0, jp0, nBy2, offs + nBy2);

        // P5 <- (A11 + A12) B22
        //  T0 <- (A11 + A12), T1 <- B22, P5 <- T0*T1
        for (int i = 0; i < nBy2; ++i) {
            for (int j = 0; j < nBy2; ++j) {
                t0[i + i0][j + jp0] = a[i + i0][j + j0] + a[i + i0][j + j1];
                t1[i + i0][j + jp0] = b[i + i1][j + j1];
            }
        }

        mult(p5, t0, t1, i0, jp0, nBy2, offs + nBy2);

        // P6 <- (A21 - A11)(B11 + B12)
        //  T0 <- (A21 - A11), T1 <- (B11 + B12), P6 <- T0 * T1
        // (Previously T1 was B11 - B12, which made C22 wrong for every n >= 2.)
        for (int i = 0; i < nBy2; ++i) {
            for (int j = 0; j < nBy2; ++j) {
                t0[i + i0][j + jp0] = a[i + i1][j + j0] - a[i + i0][j + j0];
                t1[i + i0][j + jp0] = b[i + i0][j + j0] + b[i + i0][j + j1];
            }
        }

        mult(p6, t0, t1, i0, jp0, nBy2, offs + nBy2);

        // P7 <- (A12 - A22)(B21 + B22)
        //  T0 <- (A12 - A22), T1 <- (B21 + B22), P7 <- T0 * T1
        for (int i = 0; i < nBy2; ++i) {
            for (int j = 0; j < nBy2; ++j) {
                t0[i + i0][j + jp0] = a[i + i0][j + j1] - a[i + i1][j + j1];
                t1[i + i0][j + jp0] = b[i + i1][j + j0] + b[i + i1][j + j1];
            }
        }

        mult(p7, t0, t1, i0, jp0, nBy2, offs + nBy2);

        // combine
        for (int i = 0; i < nBy2; ++i) {
            for (int j = 0; j < nBy2; ++j) {
                // C11 = P1 + P4 - P5 + P7;
                c[i + i0][j + j0] = p1[i + i0][j + jp0] + p4[i + i0][j + jp0] - p5[i + i0][j + jp0] + p7[i + i0][j + jp0];
                // C12 = P3 + P5;
                c[i + i0][j + j1] = p3[i + i0][j + jp0] + p5[i + i0][j + jp0];
                // C21 = P2 + P4;
                c[i + i1][j + j0] = p2[i + i0][j + jp0] + p4[i + i0][j + jp0];
                // C22 = P1 + P3 - P2 + P6;
                c[i + i1][j + j1] = p1[i + i0][j + jp0] + p3[i + i0][j + jp0] - p2[i + i0][j + jp0] + p6[i + i0][j + jp0];
            }
        }
    }

    /**
     * Schoolbook product of the n x n blocks of {@code a} and {@code b} at (i0, j0), written
     * to the same block of {@code c}. Used where the block size is odd.
     */
    private static void multiplyBlock(int[][] c, int[][] a, int[][] b, int i0, int j0, int n) {
        for (int i = 0; i < n; ++i) {
            for (int j = 0; j < n; ++j) {
                int sum = 0;
                for (int k = 0; k < n; ++k) {
                    sum += a[i0 + i][j0 + k] * b[i0 + k][j0 + j];
                }
                c[i0 + i][j0 + j] = sum;
            }
        }
    }

    /** Debug helper: prints every scratch buffer via TestIntMatrixMultiplication.dumpMatrix. */
    void dumpInternal () {
        System.out.println("P1");
        TestIntMatrixMultiplication.dumpMatrix(p1);
        System.out.println("P2");
        TestIntMatrixMultiplication.dumpMatrix(p2);
        System.out.println("P3");
        TestIntMatrixMultiplication.dumpMatrix(p3);
        System.out.println("P4");
        TestIntMatrixMultiplication.dumpMatrix(p4);
        System.out.println("P5");
        TestIntMatrixMultiplication.dumpMatrix(p5);
        System.out.println("P6");
        TestIntMatrixMultiplication.dumpMatrix(p6);
        System.out.println("P7");
        TestIntMatrixMultiplication.dumpMatrix(p7);
        System.out.println("T0");
        TestIntMatrixMultiplication.dumpMatrix(t0);
        System.out.println("T1");
        TestIntMatrixMultiplication.dumpMatrix(t1);
    }
}


class Classical{
    public String getName () {
        return "classic";
    }

    public static int[][] mult (int[][] c, int[][] a, int[][] b) {
        int n = a.length;

        for(int i=0; i<n; i++) {
            final int[] a_i = a[i];
            final int[] c_i = c[i];

            for(int j=0; j<n; j++) {
                int sum = 0;

                for(int k=0; k<n; k++) {
                    sum += a_i[k] * b[k][j];
                }

                c_i[j] = sum;
            }
        }

        return c;
    }
}
4

3 回答 3

5

我看到的问题:

1)您的 Strassen 乘法一直在动态分配内存。这会扼杀性能。

2)您的 Strassen 乘法应该切换到小尺寸的常规乘法,而不是一直递归(尽管这种优化会使您的测试无效)。

3)您的矩阵大小可能太小而看不到差异。

您应该对几种不同的尺寸进行比较,例如 256、512、1024、2048、4096、8192……然后把耗时画成图,观察趋势。如果尺寸都是 2 的幂,您可能需要让矩阵尺寸轴采用对数刻度。

Strassen 只有在 N 很大时才更快,至于要多大,在很大程度上取决于实现。您的经典算法只是一个基本实现,在现代机器上也并非最优。

于 2011-06-06T14:21:26.013 回答
2

除了实施问题,我认为您误解了算法的性能。就像 phkahler 所说,您对算法性能的期望有点偏离。分治算法适用于大输入,因为它们递归地将问题分解为可以更快解决的子问题。

但是,这种拆分操作带来的开销可能导致算法在小型甚至中型输入上运行得更慢(有时慢得多)。通常,对 Strassen 这类算法的理论分析会包含所谓“断点”(交叉点)的计算,即拆分的开销开始变得划算、使该算法优于简单方法的那个输入规模。

您的代码需要包含对在断点处切换到简单技术的输入大小的检查。

于 2011-06-06T14:59:17.397 回答
1

写出 Strassen 算法对 2 x 2 矩阵做了什么,数一数运算次数。这个数字简直荒谬。对 2x2 矩阵使用 Strassen 方法是愚蠢的。对于 3 x 3 或 4 x 4 的矩阵也是如此,而且对大得多的矩阵很可能同样如此。

于 2014-04-03T17:25:39.583 回答