2

我只是在自学算法和数据结构,我想知道是否有人有 Strassen 矩阵乘法算法的 C#(或 C++)实现?

我只是想运行它,看看它做了什么,并更多地了解它是如何工作的。

4

2 回答 2

3

免责声明:我没有尝试过任何这些,但它们似乎是 OP 正在寻找的。这些链接只是通过查看一些 Google 代码搜索结果。

我找到了一个C# 版本。该项目没有任何多余的装饰;这只是源头。但是,它似乎只是从我的第一次粗略扫描开始执行算法。特别是,您将要查看此文件

对于 C++,我在这个 google code project中找到了一些代码。代码当然是英文的,但 wiki 都是西里尔文(俄文?)。您将主要看这个文件。它似乎同时具有 Strassen 算法的串行和并行版本。

这些项目可能并不完全正确,但它们是您可能想要更仔细查看的东西。

于 2010-04-28T23:26:44.627 回答
1
// Recursive matrix mult by strassen's method.
// 2013-Feb-15 Fri 11:47 by moshahmed/at/gmail.

#include <assert.h>
#include <stdio.h>
#include <stdlib.h>
#include <time.h>

#define M 2
#define N (1<<M)

typedef double datatype;
#define DATATYPE_FORMAT "%4.2g"
typedef datatype mat[N][N]; // mat[2**M,2**M]  for divide and conquer mult.
typedef struct { int ra, rb, ca, cb; } corners; // for tracking rows and columns.
// A[ra..rb][ca..cb] .. the 4 corners of a matrix.

// set A[a] = I
void identity(mat A, corners a){
  int i,j;
  for(i=a.ra;i<a.rb;i++)
    for(j=a.ca;j<a.cb;j++)
      A[i][j] = (datatype) (i==j);
}

// set A[a] = k
void set(mat A, corners a, datatype k){
  int i,j;
  for(i=a.ra;i<a.rb;i++)
    for(j=a.ca;j<a.cb;j++)
      A[i][j] = k;
}

// set A[a] = [random(l..h)].
void randk(mat A, corners a, double l, double h){
  int i,j;
  for(i=a.ra;i<a.rb;i++)
    for(j=a.ca;j<a.cb;j++)
      A[i][j] = (datatype) (l + (h-l) * (rand()/(double)RAND_MAX));
}

// Print A[a]
void print(mat A, corners a, char *name) {
  int i,j;
  printf("%s = {\n",name);
  for(i=a.ra;i<a.rb;i++){
    for(j=a.ca;j<a.cb;j++)
      printf(DATATYPE_FORMAT ", ", A[i][j]);
    printf("\n");
  }
  printf("}\n");
}

// C[c] = A[a] + B[b]
void add(mat A, mat B, mat C, corners a, corners b, corners c) {
  int rd = a.rb - a.ra;
  int cd = a.cb - a.ca;
  int i,j;
  for(i = 0; i<rd; i++ ){
    for(j = 0; j<cd; j++ ){
      C[i+c.ra][j+c.ca] = A[i+a.ra][j+a.ca] + B[i+b.ra][j+b.ca];
    }
  }
}

// C[c] = A[a] - B[b]
void  sub(mat A, mat B, mat C, corners a, corners b, corners c) {
  int rd = a.rb - a.ra;
  int cd = a.cb - a.ca;
  int i,j;
  for(i = 0; i<rd; i++ ){
    for(j = 0; j<cd; j++ ){
      C[i+c.ra][j+c.ca] = A[i+a.ra][j+a.ca] - B[i+b.ra][j+b.ca];
    }
  }
}

// Return 1/4 of the matrix: top/bottom , left/right.
void find_corner(corners a, int i, int j, corners *b) {
  int rm = a.ra + (a.rb - a.ra)/2 ;
  int cm = a.ca + (a.cb - a.ca)/2 ;
  *b = a;
  if (i==0)  b->rb = rm;     // top rows
  else       b->ra = rm;     // bot rows
  if (j==0)  b->cb = cm;     // left cols
  else       b->ca = cm;     // right cols
}

// Multiply: A[a] * B[b] => C[c], recursively.
void mul(mat A, mat B, mat C, corners a, corners b, corners c) {
  corners aii[2][2], bii[2][2], cii[2][2], p;
  mat P[7], S, T;
  int i, j, m, n, k;

  // Check: A[m n] * B[n k] = C[m k]
  m = a.rb - a.ra; assert(m==(c.rb-c.ra));
  n = a.cb - a.ca; assert(n==(b.rb-b.ra));
  k = b.cb - b.ca; assert(k==(c.cb-c.ca));
  assert(m>0);

  if (n==1) {
    C[c.ra][c.ca] += A[a.ra][a.ca] * B[b.ra][b.ca];
    return;
  }

  // Create the 12 smaller matrix indexes:
  //  A00 A01   B00 B01   C00 C01
  //  A10 A11   B10 B11   C10 C11
  for(i=0;i<2;i++) {
  for(j=0;j<2;j++) {
        find_corner(a, i, j, &aii[i][j]);
        find_corner(b, i, j, &bii[i][j]);
        find_corner(c, i, j, &cii[i][j]);
      }
  }

  p.ra = p.ca = 0;
  p.rb = p.cb = m/2;

  #define LEN(A) (sizeof(A)/sizeof(A[0]))
  for(i=0; i < LEN(P); i++) set(P[i], p, 0);

  #define ST0 set(S,p,0); set(T,p,0)

  // (A00 + A11) * (B00+B11) = S * T = P0
  ST0;
  add( A, A, S, aii[0][0], aii[1][1], p);
  add( B, B, T, bii[0][0], bii[1][1], p);
  mul( S, T, P[0], p, p, p);

  // (A10 + A11) * B00 = S * B00 = P1
  ST0;
  add( A, A, S, aii[1][0], aii[1][1], p);
  mul( S, B, P[1], p, bii[0][0], p);

  // A00 * (B01 - B11) = A00 * T = P2
  ST0;
  sub( B, B, T, bii[0][1], bii[1][1], p);
  mul( A, T, P[2], aii[0][0], p, p);

  // A11 * (B10 - B00) = A11 * T = P3
  ST0;
  sub(B, B, T, bii[1][0], bii[0][0], p);
  mul(A, T, P[3], aii[1][1], p, p);

  // (A00 + A01) * B11 = S * B11 = P4
  ST0;
  add(A, A, S, aii[0][0], aii[0][1], p);
  mul(S, B, P[4], p, bii[1][1], p);

  // (A10 - A00) * (B00 + B01) = S * T = P5
  ST0;
  sub(A, A, S, aii[1][0], aii[0][0], p);
  add(B, B, T, bii[0][0], bii[0][1], p);
  mul(S, T, P[5], p, p, p);

  // (A01 - A11) * (B10 + B11) = S * T = P6
  ST0;
  sub(A, A, S, aii[0][1], aii[1][1], p);
  add(B, B, T, bii[1][0], bii[1][1], p);
  mul(S, T, P[6], p, p, p);

  // P0 + P3 - P4 + P6 = S - P4 + P6 = T + P6 = C00
  add(P[0], P[3], S, p, p, p);
  sub(S, P[4], T, p, p, p);
  add(T, P[6], C, p, p, cii[0][0]);

  // P2 + P4 = C01
  add(P[2], P[4], C, p, p, cii[0][1]);

  // P1 + P3 = C10
  add(P[1], P[3], C, p, p, cii[1][0]);

  // P0 + P2 - P1 + P5 = S - P1 + P5 = T + P5 = C11
  add(P[0], P[2], S, p, p, p);
  sub(S, P[1], T, p, p, p);
  add(T, P[5], C, p, p, cii[1][1]);

}
int main() {
  mat A, B, C;
  corners ai = {0,N,0,N};
  corners bi = {0,N,0,N};
  corners ci = {0,N,0,N};
  srand(time(0));
  // identity(A,bi); identity(B,bi);
  // set(A,ai,2); set(B,bi,2);
  randk(A,ai, 0, 2); randk(B,bi, 0, 2);
  print(A, ai, "A"); print(B, bi, "B");
  set(C,ci,0);
  // add(A,B,C, ai, bi, ci);
  mul(A,B,C, ai, bi, ci);
  print(C, ci, "C");
  return 0;
}  
于 2013-02-15T09:18:26.310 回答