卷积扩展知识
分治卷积
问题
已知$g(i)$的各项函数值
$f(i)=\sum_{j=1}^i g(j)*f(i-j)$
求$f(i)$的各项函数值
解法
考虑cdq分治思想
每次二分,先把左边的f(i)计算出来, 然后计算左边的f(i)对右边的贡献,再继续累积右边的贡献
二分到达边界时,表明这个点的函数值已经统计完毕
同理,当二分完一个区间时,表明该区间所有函数值已计算完毕
举例:
假设一开始知道f(0)的值
二分到区间0~1时,左边区间0~0已知,那么可以用f(0)计算f(1),另外f(1)除了f(0)无其他贡献来源,所以f(1)计算完毕
(绿色表示计算完成,黄色表示正在计算中)
回退到0~2时,0~1已知,可以用于计算f(1)~f(2)
进入2~2,到达边界,f(2)计算完成,回退,累计f(2)对f(3)的贡献
进入3~3,到达边界,f(3)计算完成,回退至0~7区间,累计f(0~3)对f(4~7)的贡献
之后以此类推即可
代码
代码中有些细节解释
#includeusing namespace std; #define N 300000 #define int long long int g[N],f[N],res[N],ind,rev[N],ta[N],tb[N]; const int p=998244353; int qpow(int aa,int bb) { int res=1; aa%=p; while(bb) { if(bb&1) res*=aa,res%=p; aa*=aa,aa%=p; bb>>=1ll; } return res; } void ntt(int arr[],int g,int n) { for(int i=1;i<=n;i++) { if(i >1]>>1)|((i&1)<<(x-1)); ntt(ta,3,len); ntt(tb,3,len); for(int i=0;i<=len;i++) ans[i]=ta[i]*tb[i]%p; int inv=qpow(3,p-2); ntt(ans,inv,len); //ntt(a,inv,len,p); //ntt(b,inv,len,p); int z=qpow(len,p-2); for(int i=0;i<=len;i++) ans[i]=ans[i]*z%p,ta[i]=tb[i]=0; } void divide(int l,int r) { if(l==r) return; int mid=(l+r)/2; divide(l,mid); memset(res,0,16*(r-l+1)); memcpy(ta,f+l,8*(mid-l+1)); memcpy(tb,g,8*(r-l+1));//实际是f(l~mid)*g(mid+1~r) 但为了凑足g的次数还是从g(1)开始 mul(res,r-l+1);//乘出来的res应该是r-l+1+mid-l+1项的,但我们只关心mid+1~r项,所以只需要计算1~r-l+1项就行了 for(int i=mid+1;i<=r;i++) f[i]+=res[i-l],f[i]%=p; divide(mid+1,r); } signed main() { int n; cin>>n; n--; for(int i=1;i<=n;i++) scanf("%lld",&g[i]); f[0]=1; int t=1; while(t 任意模数卷积
如果题目的模数不是NTT模数,甚至没有模数,并且值域范围很大,fft会掉精度
介绍两种办法
拆系数fft
将多项式系数拆为$a_i=b_i*m+c_I$,m是阈值,一般取1e5,这样如果$a_i<=10^9,则b_i,c_i<=10^5$,乘起来不会太大
这样$f(x)=f_1(x)*m+f_2(x)$
然后$f(x)*g(x)=f_1(x)*g_1(x)*m^2+(f_1(x)*g_2(x)+f_2(x)*g_1(x))*m+f_2(x)*g_2(x)$
做四次fft即可
三模数ntt
代码
#includeusing namespace std; #define N 300000 #define int long long int ta[N],tb[N],a[N],b[N],ans[5][N],p[4]={0,469762049,998244353,1004535809},rev[N]; int fmul(int a, int b, int mod) {//用于计算会爆long long的乘法 a %= mod, b %= mod; return ((a * b - (int)((int)((long double)a / mod * b + 1e-3) * mod)) % mod + mod) % mod; } int qpow(int aa,int bb,int pp) { int res=1; aa%=pp; while(bb) { if(bb&1) res*=aa,res%=pp; aa*=aa,aa%=pp; bb>>=1ll; } return res; } void ntt(int arr[],int g,int n,int p) { for(int i=1;i<=n;i++) { if(i >1]>>1)|((i&1)<<(l-1)); } ntt(a,3,len,p); ntt(b,3,len,p); for(int i=0;i<=len;i++) ans[i]=a[i]*b[i]%p; int inv=qpow(3,p-2,p); ntt(ans,inv,len,p); //ntt(a,inv,len,p); //ntt(b,inv,len,p); for(int i=0;i<=len;i++) ans[i]=ans[i]*qpow(len,p-2,p)%p; } signed main() { int n,m,p0; cin>>n>>m>>p0; while(len