我最近也想实现NTT来实现快速乘法而不是DFFT。读了很多令人困惑的东西,到处都是不同的字母,没有简单的解决方案,而且我的有限域知识也生疏了,但是今天我终于弄对了(经过两天的尝试和模拟DFT系数)所以这是我的见解对于NTT:
计算
X(i) = sum(j=0..n-1) of ( Wn^(i*j)*x(i) );
哪里X[]
是NTT变换x[]
的大小n
,哪里Wn
是NTT基。所有计算都是在整数模算术上进行的,mod p
任何地方都没有复数。
重要价值观
Wn = r ^ L mod p
是NTT
Wn = r ^ (p-1-L) mod p
的基础 是INTT的基础 是INTT
Rn = n ^ (p-2) mod p
的缩放乘法常数是素数,并且是NTT的x[i]或INTT的X[i]的最大值,并且除法必须组合,所以 如果或必须组合, 如果是子结果最大值取决于计算类型和类型。对于单个(I)NTT,它是但对于两个大小的向量的卷积,它是等等。有关它的更多信息,请参阅在有限域上实现 FFT 。 ~(1/n)
p
p mod n == 1
p>max'
max
r = <1,p)
L = <1,p)
p-1
r,L
r^(L*i) mod p == 1
i=0
i=n
r,L
r^(L*i) mod p != 1
0 < i < n
max'
n
max' = n*max
n
max' = n*max*max
不同的功能组合r,L,p
不同n
这很重要,您必须在每个 NTT 层之前重新计算或从表中选择参数(n
始终是前一个递归的一半)。
这是我找到参数的C++r,L,p
代码(需要不包括在内的模运算,您可以将其替换为 (a+b)%c,(ab)%c,(a*b)%c,... 但在这种情况下要注意溢出,特别是modpow
)modmul
代码尚未优化,但有一些方法可以显着加快速度。素数表也相当有限,因此要么使用SoE 或任何其他算法来获得素数,max'
以便安全工作。
DWORD _arithmetics_primes[]=
{
2,3,5,7,11,13,17,19,23,29,31,37,41,43,47,53,59,61,67,71,73,79,83,89,97,101,103,107,109,113,127,131,137,139,149,151,157,163,167,173,
179,181,191,193,197,199,211,223,227,229,233,239,241,251,257,263,269,271,277,281,283,293,307,311,313,317,331,337,347,349,353,359,367,373,379,383,389,397,401,409,
419,421,431,433,439,443,449,457,461,463,467,479,487,491,499,503,509,521,523,541,547,557,563,569,571,577,587,593,599,601,607,613,617,619,631,641,643,647,653,659,
661,673,677,683,691,701,709,719,727,733,739,743,751,757,761,769,773,787,797,809,811,821,823,827,829,839,853,857,859,863,877,881,883,887,907,911,919,929,937,941,
947,953,967,971,977,983,991,997,1009,1013,1019,1021,1031,1033,1039,1049,1051,1061,1063,1069,1087,1091,1093,1097,1103,1109,1117,1123,1129,1151,
0}; // end of table is 0, the more primes are there the bigger numbers and n can be used
// compute NTT consts W=r^L%p for n
int i,j,k,n=16;
long w,W,iW,p,r,L,l,e;
long max=81*n; // edit1: max num for NTT for my multiplication purposses
for (e=1,j=0;e;j++) // find prime p that p%n=1 AND p>max ... 9*9=81
{
p=_arithmetics_primes[j];
if (!p) break;
if ((p>max)&&(p%n==1))
for (r=2;r<p;r++) // check all r
{
for (l=1;l<p;l++)// all l that divide p-1
{
L=(p-1);
if (L%l!=0) continue;
L/=l;
W=modpow(r,L,p);
e=0;
for (w=1,i=0;i<=n;i++,w=modmul(w,W,p))
{
if ((i==0) &&(w!=1)) { e=1; break; }
if ((i==n) &&(w!=1)) { e=1; break; }
if ((i>0)&&(i<n)&&(w==1)) { e=1; break; }
}
if (!e) break;
}
if (!e) break;
}
}
if (e) { error; } // error no combination r,l,p for n found
W=modpow(r, L,p); // Wn for NTT
iW=modpow(r,p-1-L,p); // Wn for INTT
这是我的慢 NTT 和 INTT 实现(我还没有快速 NTT,INTT),它们都成功地用 Schönhage-Strassen 乘法进行了测试。
//---------------------------------------------------------------------------
void NTT(long *dst,long *src,long n,long m,long w)
{
long i,j,wj,wi,a,n2=n>>1;
for (wj=1,j=0;j<n;j++)
{
a=0;
for (wi=1,i=0;i<n;i++)
{
a=modadd(a,modmul(wi,src[i],m),m);
wi=modmul(wi,wj,m);
}
dst[j]=a;
wj=modmul(wj,w,m);
}
}
//---------------------------------------------------------------------------
void INTT(long *dst,long *src,long n,long m,long w)
{
long i,j,wi=1,wj=1,rN,a,n2=n>>1;
rN=modpow(n,m-2,m);
for (wj=1,j=0;j<n;j++)
{
a=0;
for (wi=1,i=0;i<n;i++)
{
a=modadd(a,modmul(wi,src[i],m),m);
wi=modmul(wi,wj,m);
}
dst[j]=modmul(a,rN,m);
wj=modmul(wj,w,m);
}
}
//---------------------------------------------------------------------------
dst
是目标数组
src
是源数组
n
是数组大小
m
是模数 ( p
)
w
是基数 ( Wn
)
希望这对某人有所帮助。如果我忘记了什么请写...
[edit1:快速 NTT/INTT]
最后我设法让NTT/INTT快速工作。比普通的FFT有点棘手:
//---------------------------------------------------------------------------
void _NFTT(long *dst,long *src,long n,long m,long w)
{
if (n<=1) { if (n==1) dst[0]=src[0]; return; }
long i,j,a0,a1,n2=n>>1,w2=modmul(w,w,m);
// reorder even,odd
for (i=0,j=0;i<n2;i++,j+=2) dst[i]=src[j];
for ( j=1;i<n ;i++,j+=2) dst[i]=src[j];
// recursion
_NFTT(src ,dst ,n2,m,w2); // even
_NFTT(src+n2,dst+n2,n2,m,w2); // odd
// restore results
for (w2=1,i=0,j=n2;i<n2;i++,j++,w2=modmul(w2,w,m))
{
a0=src[i];
a1=modmul(src[j],w2,m);
dst[i]=modadd(a0,a1,m);
dst[j]=modsub(a0,a1,m);
}
}
//---------------------------------------------------------------------------
void _INFTT(long *dst,long *src,long n,long m,long w)
{
long i,rN;
rN=modpow(n,m-2,m);
_NFTT(dst,src,n,m,w);
for (i=0;i<n;i++) dst[i]=modmul(dst[i],rN,m);
}
//---------------------------------------------------------------------------
[编辑3]
我已经优化了我的代码(比上面的代码快 3 倍),但我仍然对它不满意,所以我开始用它提出新问题。在那里,我进一步优化了我的代码(比上面的代码快大约 40 倍),因此它的速度几乎与FFT在相同位大小的浮点上的速度相同。它的链接在这里: