嗯,这是《算法导论》中的一道题,编号为 4.2-6。题目是这样描述的:
使用 Strassen 算法作为子程序时,将一个 $kn \times n$ 矩阵与一个 $n \times kn$ 矩阵相乘,最快能有多快?
我考虑的做法是把两个矩阵都扩展成 $kn \times kn$ 矩阵,然后就可以直接把 Strassen 算法应用到这个问题上。但这样得到的运行时间是 $\Theta((kn)^{\lg 7})$。
有没有人有更好的解决方案?祝大家新年快乐。
嗯,这是《算法导论》中的一道题,编号为 4.2-6。题目是这样描述的:
使用 Strassen 算法作为子程序时,将一个 $kn \times n$ 矩阵与一个 $n \times kn$ 矩阵相乘,最快能有多快?
我考虑的做法是把两个矩阵都扩展成 $kn \times kn$ 矩阵,然后就可以直接把 Strassen 算法应用到这个问题上。但这样得到的运行时间是 $\Theta((kn)^{\lg 7})$。
有没有人有更好的解决方案?祝大家新年快乐。
不妨换个角度,把它看作一个 $k \times 1$ 向量乘以一个 $1 \times k$ 向量。这需要 $k^2$ 次乘法,最后得到一个 $k \times k$ 矩阵。这里唯一的不同是向量的元素本身是 $n \times n$ 矩阵,所以如果用 Strassen 算法来完成每一次 $n \times n$ 矩阵乘法,最终需要 $O(k^2 n^{\lg 7})$ 次标量乘法。
下面是 Strassen 算法的另一个基于 vector 的实现,它比较了朴素算法和 Strassen 算法的运行时间:
enter code here:
#include <cstdio>
#include <iostream>
#include <cstdlib>
#include <ctime>
#include <cassert>
#include <vector>
#include <ctime>
using namespace std;
// Copies one n x n quadrant of u into the first n rows of m.
//
// u must hold at least 2n rows and 2n columns, and m at least n rows.
// P selects the quadrant:
//   1 = top-left, 2 = top-right, 3 = bottom-left, 4 = bottom-right.
// For any other P, u is not read and each of the first n rows of m is
// replaced by an empty row.
void fun(vector<vector<int> >& u , vector<vector<int> >&m , int P , int n)
{
    const bool validQuadrant = (P >= 1 && P <= 4);
    // Quadrants 3 and 4 sit in the lower half, quadrants 2 and 4 in the right half.
    const int rowShift = (P == 3 || P == 4) ? n : 0;
    const int colShift = (P == 2 || P == 4) ? n : 0;

    for (int row = 0; row < n; ++row)
    {
        vector<int> picked;
        if (validQuadrant)
        {
            const vector<int>& source = u[row + rowShift];
            picked.assign(source.begin() + colShift,
                          source.begin() + colShift + n);
        }
        m[row] = picked;
    }
}
// Schoolbook matrix product: stores u * v into z, using the leading
// n x n entries of each matrix. Every z[row][col] is reset to 0 before
// the products are accumulated into it, so z's previous contents are ignored.
void normalmul(int n , vector< vector<int> >& u , vector< vector<int> >& v , vector< vector<int> >& z )
{
    for (int row = 0; row < n; ++row)
    {
        vector<int>& outRow = z[row];
        for (int col = 0; col < n; ++col)
        {
            // Accumulate straight into the destination cell.
            int& cell = outRow[col];
            cell = 0;
            for (int inner = 0; inner < n; ++inner)
            {
                cell += u[row][inner] * v[inner][col];
            }
        }
    }
}
// Element-wise helper for strassen(): writes a + b (or a - b when
// `subtract` is true) into out, over the leading h x h entries.
static void strassenCombine(int h ,
                            const vector< vector<int> >& a ,
                            const vector< vector<int> >& b ,
                            bool subtract ,
                            vector< vector<int> >& out)
{
    for(int i = 0 ; i < h ; i++)
    {
        for(int j = 0 ; j < h ; j++)
        {
            out[i][j] = subtract ? (a[i][j] - b[i][j]) : (a[i][j] + b[i][j]);
        }
    }
}

// Computes z = u * v for n x n matrices with Strassen's seven-product recursion.
//
// n : dimension of the square matrices
// u : left operand  (at least n x n)
// v : right operand (at least n x n)
// z : output        (at least n x n); its leading n x n entries are overwritten
//
// The recursion stops and delegates to normalmul() as soon as n is at most
// 32 or n is odd. The previous test (n == 32) was only ever reached for
// n = 32 * 2^k; every other size halved its way past 32 down to 0 and
// recursed without end, and an odd intermediate size silently lost its
// last row and column.
void strassen(int n , vector< vector<int> >& u , vector< vector<int> >& v , vector< vector<int> >& z)
{
    const int kNaiveCutoff = 32;
    if(n <= kNaiveCutoff || (n % 2) != 0)
    {
        normalmul(n,u,v,z);
        return;
    }

    typedef vector< vector<int> > Matrix;
    const int half = n >> 1;
    const Matrix blank(half , vector<int>(half));

    // Quadrants of u = [AA BB; CC DD] and v = [EE FF; GG HH].
    Matrix AA(blank), BB(blank), CC(blank), DD(blank);
    Matrix EE(blank), FF(blank), GG(blank), HH(blank);
    fun(u,AA,1,half);
    fun(u,BB,2,half);
    fun(u,CC,3,half);
    fun(u,DD,4,half);
    fun(v,EE,1,half);
    fun(v,FF,2,half);
    fun(v,GG,3,half);
    fun(v,HH,4,half);

    // The seven Strassen products and two scratch operands.
    Matrix M1(blank), M2(blank), M3(blank), M4(blank);
    Matrix M5(blank), M6(blank), M7(blank);
    Matrix T1(blank), T2(blank);

    // M1 = (AA + DD) * (EE + HH)
    strassenCombine(half,AA,DD,false,T1);
    strassenCombine(half,EE,HH,false,T2);
    strassen(half,T1,T2,M1);

    // M2 = (CC + DD) * EE
    strassenCombine(half,CC,DD,false,T1);
    strassen(half,T1,EE,M2);

    // M3 = AA * (FF - HH)
    strassenCombine(half,FF,HH,true,T2);
    strassen(half,AA,T2,M3);

    // M4 = DD * (GG - EE)
    strassenCombine(half,GG,EE,true,T2);
    strassen(half,DD,T2,M4);

    // M5 = (AA + BB) * HH
    strassenCombine(half,AA,BB,false,T1);
    strassen(half,T1,HH,M5);

    // M6 = (CC - AA) * (EE + FF)
    strassenCombine(half,CC,AA,true,T1);
    strassenCombine(half,EE,FF,false,T2);
    strassen(half,T1,T2,M6);

    // M7 = (BB - DD) * (GG + HH)
    strassenCombine(half,BB,DD,true,T1);
    strassenCombine(half,GG,HH,false,T2);
    strassen(half,T1,T2,M7);

    // Assemble the four quadrants of the result directly into z.
    for(int i = 0 ; i < half ; i++)
    {
        for(int j = 0 ; j < half ; j++)
        {
            z[i][j]               = M1[i][j] + M4[i][j] - M5[i][j] + M7[i][j];
            z[i][j + half]        = M3[i][j] + M5[i][j];
            z[i + half][j]        = M2[i][j] + M4[i][j];
            z[i + half][j + half] = M1[i][j] - M2[i][j] + M3[i][j] + M6[i][j];
        }
    }
}
// Benchmark driver: reads test cases from input_file.txt and, for each one,
// times normalmul() against strassen() on the same pair of n x n matrices.
//
// Input format, as consumed below: the number of test cases t, then for each
// case the size n followed by n*n integers for the first matrix and n*n
// integers for the second.
//
// Returns 0 on success and 1 if the input file cannot be opened or the
// input is malformed or truncated.
int main()
{
    if(!freopen("input_file.txt","r",stdin))
    {
        cerr << "cannot open input_file.txt" << endl;
        return 1;
    }

    int t = 0;
    if(scanf("%d",&t) != 1)
    {
        cerr << "cannot read the number of test cases" << endl;
        return 1;
    }

    while(t-- > 0)
    {
        int n = 0;
        if(scanf("%d",&n) != 1 || n < 0)
        {
            cerr << "cannot read a valid matrix size" << endl;
            return 1;
        }
        cout << "value of n " << n << endl ;

        vector< vector<int> >u(n,vector<int>(n));
        vector< vector<int> >v(n,vector<int>(n));
        vector< vector<int> >z(n,vector<int>(n));
        vector< vector<int> >zz(n,vector<int>(n));

        // Read the left operand, then the right operand, row by row.
        for(int i = 0 ; i < n ; i++)
        {
            for(int j = 0 ; j < n ; j++)
            {
                if(scanf("%d",&u[i][j]) != 1)
                {
                    cerr << "matrix data is truncated or malformed" << endl;
                    return 1;
                }
            }
        }
        for(int i = 0 ; i < n ; i++)
        {
            for(int j = 0 ; j < n ; j++)
            {
                if(scanf("%d",&v[i][j]) != 1)
                {
                    cerr << "matrix data is truncated or malformed" << endl;
                    return 1;
                }
            }
        }

        clock_t start , end ;

        //USING NAIVE APPROACH
        // The label is printed after the clock stops so that console output
        // is not part of the measurement (same as the Strassen branch below).
        start = clock();
        normalmul(n,u,v,z);
        end = clock() ;
        cout<<"Traditional Algorithm Running Time : ";
        cout<<static_cast<double>(end-start)/CLOCKS_PER_SEC<<" seconds"<<endl ;
        /*cout << "ANSWER OF MULTIPLICATION BY NAIVE APPROACH" << endl ;
        for(int i = 0 ; i < n ; i++)
        {
            for(int j = 0 ; j < n ; j++)
            {
                cout << z[i][j] << " ";
            }
            cout << endl ;
        }*/

        //USING STRASSENS ALGORITHM
        start = clock() ;
        strassen(n,u,v,zz);
        end = clock();
        cout<<"Strassen Algorithm Running Time : ";
        cout<<static_cast<double>(end-start)/CLOCKS_PER_SEC<<" seconds"<<endl ;
        /*cout << "ANSWER BY STRASSENS ALGORITHM " << endl ;
        for(int i = 0 ; i < n ; i++)
        {
            for(int j = 0 ; j < n ; j++)
            {
                cout << zz[i][j] << " ";
            }
            cout << endl ;
        }*/
    }
    return 0;
    /* IPG_2011006 Abhishek Yadav */
}
您可以参考 Strassen 算法的 C++ 实现;该算法在 Wikipedia 上也有很好的说明。