[CF960G] Bandit Blues
题意
给你三个正整数 \(n,a,b\),定义 \(A\) 为一个排列中是前缀最大值的数的个数,定义 \(B\) 为一个排列中是后缀最大值的数的个数,求长度为 \(n\) 的排列中满足 \(A = a\) 且 \(B = b\) 的排列个数。\(n \le 10^5\),答案对 \(998244353\) 取模。
Sol
首先可以设一个 \(DP\) 状态 \(f(i,j)\) 表示,长度为 \(i\) 的排列,有 \(j\) 个前缀最大值的方案数。
那么转移就是枚举新放一个最小值,只有放在序列开头才有 \(1\) 的贡献:
\[f(i,j)=f(i-1,j-1)+(i-1)\times f(i-1,j) \]最后的答案就是枚举最大值 \(n\) 放在位置 \(i\),然后左边长度为 \(i-1\) 且有 \(a-1\) 个前缀最大值,右边长度为 \(n-1-i\) 且有 \(b-1\) 个后缀最大值,可以发现这个后缀最大值和前缀最大值的方案是相等的,那么最终的答案就是:
\[ans=\sum_{i=1}^n C(n-1,i-1)\cdot f(i-1,a-1)\cdot f(n-i-1,b-1) \]稍微熟练一点就可以看出,这个 \(f(i,j)\) 本质上就是第一类斯特林数,即 \(i\) 个数放 \(j\) 个圆排列的方案数。
所以这个式子就可以化简了,从组合意义上理解就是,从 \(n-1\) 个数拿出来形成 \(a+b-2\) 个圆排列,其中把 \(a-1\) 个放在 \(n\) 前面的方案数。最后答案就变成了:
\[ans=s(n-1,a+b-2)\times C(a+b-2,a-1) \]只需要预处理第一类斯特林数一行就行。分治\(\mathrm{NTT}\)复杂度\(O(n\log^2 n)\),倍增复杂度\(O(n\log n)\)。具体见。
Code
两种都写了下
// 分治NTT
#pragma GCC optimize(2)
#include
using std::min;
using std::max;
using std::swap;
using std::vector;
typedef double db;
typedef long long ll;
#define pb(A) push_back(A)
#define vec std::vector
#define pii std::pair
#define all(A) A.begin(),A.end()
#define mp(A,B) std::make_pair(A,B)
const int N=4e5+5;
const int mod=998244353;
int n,A,B,lim;
int a[N],b[N],rev[N];
int ksm(int a,int b=mod-2,int ans=1){
while(b){
if(b&1) ans=1ll*ans*a%mod;
a=1ll*a*a%mod;b>>=1;
} return ans;
}
void ntt(int *f,int g){
for(int i=1;i3)
for(int in=ksm(lim),i=0;i>1]>>1)|(i&1?lim>>1:0);
for(int i=0;i>1; return mul(solve(l,mid),solve(mid+1,r),r-l+1);
}
int S(int n,int m){
if(!n) return 1;
if(n
// 倍增
#pragma GCC optimize(2)
#include
using std::min;
using std::max;
using std::swap;
using std::vector;
typedef double db;
typedef long long ll;
#define pb(A) push_back(A)
#define pii std::pair
#define all(A) A.begin(),A.end()
#define mp(A,B) std::make_pair(A,B)
const int N=4e5+5;
const int mod=998244353;
int lim,rev[N],pw[N];
int a[N],b[N],c[N],d[N];
int n,A,B,fac[N],ifac[N];
int getint(){
int X=0,w=0;char ch=getchar();
while(!isdigit(ch))w|=ch=='-',ch=getchar();
while( isdigit(ch))X=X*10+ch-48,ch=getchar();
if(w) return -X;return X;
}
int ksm(int a,int b=mod-2,int ans=1){
while(b){
if(b&1) ans=1ll*ans*a%mod;
a=1ll*a*a%mod;b>>=1;
} return ans;
}
void ntt(int *f,int g){
for(int i=1;i3)
for(int in=ksm(lim),i=0;i>1); int mid=len>>1;
pw[0]=1;for(int i=1;i<=mid;i++) pw[i]=1ll*pw[i-1]*mid%mod;
lim=1;while(lim<=len) lim<<=1;
for(int i=1;i>1]>>1)|(i&1?lim>>1:0);
for(int i=0;i<=mid;i++) c[i]=1ll*a[i]*fac[i]%mod,d[mid-i]=1ll*pw[i]*ifac[i]%mod;
ntt(c,3),ntt(d,3);
for(int i=0;i