AGC019E - Shuffle and Swap 题解


题目链接:E - Shuffle and Swap

题目大意:洛谷


题解:这一道题有两个做法。

Solution 1

考虑有 \(n\) 个位置上 \(A,B\) 均为 1 ,设为位置 A ,有 \(2m\) 个位置上 \(A\)1 , 或 \(B\)1 ,设为位置 B(不包含位置 A),那么我们的目标就是让最后位置 B 的个数减少为 0,那么我们每一次交换位置 A 中的两个数并不会是位置 B 的个数发生变化,所以不管它,如果我们交换位置 B 中的两个数则会使位置 B 的个数减 1,这种这种情况的方案数是 \(m^2\) ,如果我们交换一个位置 A 上的数和一个位置 B 上的数,则会使位置 A 的个数减 1 ,位置 B 的个数不变。

因此,我们的转移方程就是 \(f_{n,m}= f_{n,m-1}\times m^2+f_{n-1,m}\times n\times m\)

最后我们考虑位置 A 还剩余的方案数,\(ans =\sum_{i=0}^{n} f_{n-i,m} \times (i!)^2 \times \binom{n}{i}\times \binom{n+m}{i}\)

时间复杂度和空间复杂度均为 \(O(n^2)\)

Solution 2

考虑一个重要的转化:题目中的 \((k!)^2\) 中方案我们可以分成两个步骤,第一个步骤是给每一个 \(A\) 中的 1 匹配一个 \(B\) 中的 1 ,第二个步骤是给这些情况重新排列。

现在我们更改一下 \(n,m\) 的定义,令 \(n\) 表示 \(A,B\) 上有一个位置为 1 的方案数, \(m\) 的意义同 Solution 1 中的不变。

考虑步骤一,那么如果我们对于 \(A\) 中的 1 的位置,向和它匹配的 \(B\) 中的位置连一条边,那么我们会发现,整张图被我们分成了若干条链和若干个环,链的个数恰好是 \(m\) 个,并且,链的选择是有顺序要求的,即必须从链首选到链尾,而环的选择则没有要求了,接下来我们考虑使用生成函数来表示这个东西,因为两条链或者两个环在组合的时候是需要乘上组合数的,所以考虑用指数型生成函数来解决。

链的指数型生成函数:(假设起点已经确定,所以在结束之后还需要乘上\(m!\)到答案中, \([x^i]F(x)\) 表示在链的端点之间有 \(i\) 个点的方案数。)
\[F(x)=\sum_{i=0}^{\infty} \frac{i!\times x^i}{i!\times (i+1)!} = \frac{e^x-1}{x}\]

环的指数型生成函数:(\([x^i]G(x)\)表示 \(i\) 个点的环的方案数。)
\[G(x)=\sum_{i=1}^{\infty} \frac{(i-1)!\times i!\times x^i}{i!\times i!} = -\ln(1-x)\]

所以我们需要将环和链组合起来,因为链的组合是有序的,而环的组合是无序的,所以最后的结果就是:
\[n!\times m!\times (n-m)! \times \sum_{i=0}^{n-m} ([x^i]F^m(x)) ([x^i]\exp(G(x)))\]

然后把函数带进去展开得到:
\[n!\times m!\times (n-m)! \times \sum_{i=0}^{n-m} [x^i](\frac{e^x-1}{x})^m\]

然后就可以直接计算答案了,时间复杂度 \(O(n\log n)\),空间复杂度 \(O(n)\)

Solution 1 的代码:

#include 
int quick_power(int a,int b,int Mod){
    int ans=1;
    while(b){
        if(b&1){
            ans=1ll*ans*a%Mod;
        }
        b>>=1;
        a=1ll*a*a%Mod;
    }
    return ans;
}
const int Maxn=10000;
const int Mod=998244353;
int f[Maxn+5][Maxn+5];
int n,k;
char a[Maxn+5],b[Maxn+5];
int s_1,s_2;
int frac[Maxn+5],inv_f[Maxn+5];
void init(){
    frac[0]=1;
    for(int i=1;i<=Maxn;i++){
        frac[i]=1ll*frac[i-1]*i%Mod;
    }
    inv_f[Maxn]=quick_power(frac[Maxn],Mod-2,Mod);
    for(int i=Maxn-1;i>=0;i--){
        inv_f[i]=1ll*inv_f[i+1]*(i+1)%Mod;
    }
}
int C(int n,int m){
    return 1ll*frac[n]*inv_f[m]%Mod*inv_f[n-m]%Mod;
}
int main(){
    init();
    scanf("%s",a+1);
    scanf("%s",b+1);
    while(a[++n]!='\0');
    n--;
    for(int i=1;i<=n;i++){
        if(a[i]=='1'){
            k++;
        }
        if(a[i]=='1'&&b[i]=='1'){
            s_1++;
        }
        else if(a[i]=='1'){
            s_2++;
        }
    }
    f[0][0]=1;
    for(int i=0;i<=s_1;i++){
        for(int j=1;j<=s_2;j++){
            if(i==0&&j==0){
                continue;
            }
            f[i][j]=(f[i][j]+1ll*f[i][j-1]*j%Mod*j)%Mod;
            if(i>0){
                f[i][j]=(f[i][j]+1ll*f[i-1][j]*i%Mod*j)%Mod;
            }
        }
    }
    int ans=0;
    for(int i=0;i<=s_1;i++){
        ans=(ans+1ll*f[s_1-i][s_2]*frac[i]%Mod*frac[i]%Mod*C(s_1,i)%Mod*C(k,i))%Mod;
    }
    printf("%d\n",ans);
    return 0;
}

Solution 2 的代码:

#include 
#include 
#include 
using namespace std;
int quick_power(int a,int b,int Mod){
    int ans=1;
    while(b){
        if(b&1){
            ans=1ll*ans*a%Mod;
        }
        b>>=1;
        a=1ll*a*a%Mod;
    }
    return ans;
}
const int Maxn=40000;
const int G=3;
const int Mod=998244353;
int n,m,len;
char a[Maxn+5],b[Maxn+5];
void NTT(int *a,int flag,int n){
    static int R[Maxn+5];
    int len=1,L=0;
    while(len>1]>>1)|((i&1)<<(L-1));
    }
    for(int i=0;i>1);
    for(int i=0;i0;i--){
        a[i]=1ll*quick_power(i,Mod-2,Mod)*a[i-1]%Mod;
    }
    a[0]=0;
}
void find_ln(int *a,int *b,int n){
    static int c[Maxn+5];
    for(int i=0;i>1);
    find_ln(b,c,len);
    c[0]=(a[0]+1-c[0]+Mod)%Mod;
    for(int i=1;i=0;i--){
        inv_f[i]=1ll*inv_f[i+1]*(i+1)%Mod;
    }
}
int main(){
    init();
    scanf("%s",a+1);
    scanf("%s",b+1);
    while(a[++len]!='\0');
    for(int i=1;i<=len;i++){
        if(a[i]=='1'){
            n++;
            if(b[i]=='0'){
                m++;
            }
        }
    }
    for(int i=0;i<=n-m;i++){
        f[i]=inv_f[i+1];
    }
    int len=1;
    while(len<=n-m){
        len<<=1;
    }
    find_ln(f,g,len);
    memset(f,0,sizeof f);
    for(int i=0;i<=n-m;i++){
        f[i]=1ll*g[i]*m%Mod;
    }
    memset(g,0,sizeof g);
    find_exp(f,g,len);
    int ans=0;
    for(int i=0;i<=n-m;i++){
        f[i]=g[i];
        ans=(ans+f[i])%Mod;
    }
    ans=1ll*ans*frac[m]%Mod*frac[n-m]%Mod*frac[n]%Mod;
    printf("%d\n",ans);
    return 0;
}