0

我正在尝试使用 Karatsuba 的方法为多项式乘法实现一个简单的分而治之算法,即使用该方法进行p=a+b*x^k, , q=c+d*x^kwith和递归计算, 其中某处接近or度数的一半。p*q=ac+(ad+bc)x^k+bd*x^(2k)ad+bc=ac+bd-(a-b)*(c-d)ac,bd,(a-b)(c-d)kpq

以下代码在用于对度数 >= 64 左右的多项式(随机生成的介于 0 和 9 之间的整数系数)求平方时崩溃而没有错误消息,但似乎适用于较小的度数。

补充:(程序没有正确终止。另外,对于 64 级,应该只使用大约 3^(log(64))=729 递归函数调用,所以我认为如果代码有效,可以排除堆栈溢出正如在这方面的意图。)

单独使用的蛮力功能似乎在很大程度上也可以正常工作。

struct poly {
 int deg;
 double* coeff;
};

poly stdpolmult(poly p,poly q) { // standard algorithm 
 poly r;

 r.deg= p.deg+q.deg;
 r.coeff = (double*) calloc (r.deg,sizeof(double));
 int i,j;

 for (i=0;i<=p.deg;i++)
  for (j=0;j<=q.deg;j++)
   r.coeff[i+j]=r.coeff[i+j]+p.coeff[i]*q.coeff[j];

 return r;
}

poly fastpolmult(poly p,poly q) {  // Divide & Conquer
 if ((p.deg<=7)&&(q.deg<=7))
  return stdpolmult(p,q); // brute force

 poly a,b,c,d,u,v,x,y,w,z,s,r;
 int k=p.deg/2;
 if (p.deg<q.deg)
  k=q.deg/2;
 a=polfirstpart(p,k);
 b=pollastpart(p,k);
 c=polfirstpart(q,k);
 d=pollastpart(q,k); /* let p=p_1+x^k*p_2, q=q_1+x^k+q_2, then
                       a= p_1,b=p_2,c=q_1,d=q_2 */

 u = fastpolmult(a,c); // u =p_1*p_2
 v = fastpolmult(b,d); // v =q_1*q_2
 polneg(b); // b= -p_2
 polneg(d); // d= -q_2
 x=poladd(a,b); // x=p_1-p_2
 y=poladd(c,d); // y=q_1-q_2
 w=fastpolmult(x,y); // w=(p_1-p_2)*(q_1-q_2)
 polneg(w); // w= -(p_1-p_2)*(q_1-q_2)
 z=poladd(u,v); // z=p_1*p_2+q_1*q_2
 s=poladd(z,w); // s=p_1*p_2+q_1*q_2-(p_1-p_2)*(q_1-q_2) = p_1*q_2+p_2*q_1

 polfree(z); polfree(w); polfree(x); polfree(y);

 x=polshift(s,k); // x=(p_1*q_2+p_2*q_1)*x^k
 y=polshift(v,2*k); // y=q_1*q_2*x^(2k)

 z=poladd(u,x); // z=p_1*p_2+(p_1*q_2+p_2*q_1)*x^k
 r=poladd(z,y); // r = p_1*p_2+(p_1*q_2+p_2*q_1)*x^k +q_1*q_2*x^(2k) = p*q

 polfree(x);polfree(y);polfree(z);

 return r;
 }

如果需要,我可以为使用的函数(polfirstpart、pollastpart、polneg、poladd、polshift、polfree)添加我的代码。他们没什么特别的。(我单独测试了它们,它们似乎有效)。

补充:大多数这些功能的代码:

poly poladd(poly p,poly q) {
 int i,n;
 poly r;
 if (p.deg>=q.deg)
  r.deg=p.deg;
 else
  r.deg=q.deg;
 r.coeff = (double*) calloc (r.deg+1,sizeof(double));
 if (p.deg<=q.deg)
  n=p.deg;
 else
  n=q.deg;

 for (i=0;i<=n;i++)
  r.coeff[i]=p.coeff[i]+q.coeff[i];

 if (p.deg>q.deg)
  for (i=n+1;i<=p.deg;i++)
   r.coeff[i]=p.coeff[i];
 else
  for (i=n+1;i<=q.deg;i++)
   r.coeff[i]=q.coeff[i];

 return r;
}

poly polfirstpart(poly p, int k) { /* if p=a+x^k*b, take a */
 poly r;
 if (k<=0)
  return zpol; // zero pol
 if (k>p.deg)
  return p;
 r.coeff=p.coeff;
 r.deg=k-1;
 return r;
}

 poly pollastpart(poly p,int k) { /* if p=a+x^k*b, take b */
  poly r;
  if (k<0)
   return zpol;
  if (k>p.deg)
   return zpol;

  r.coeff=(p.coeff)+k;
  r.deg=p.deg-k;
  return r;
 }

poly polshift(poly p,int k) {  /* x^k*p */
 int i;
 poly r;
 r.deg=p.deg+k;
 r.coeff= (double*) calloc(r.deg+1,sizeof(double));
 for (i=0;i<=p.deg;i++)
  r.coeff[k+i]=p.coeff[i];
 return r;
}

void genzpol() { // called in main
 zpol.deg=0;
 zpol.coeff=(double*) calloc(1,sizeof(double));
}
4

0 回答 0