The 2021 ICPC Asia Shanghai Regional Programming Contest - B. Strange Permutations 题解



title: >-
The 2021 ICPC Asia Shanghai Regional Programming Contest - B. Strange
Permutations
date: 2021-12-13 15:27:09
tags: [inclusion-exclusion, combinatorics, math, FFT, team training, merge]

题意

给一个全排列 \(P\) ,计算构造全排列 \(Q\) 使得 \(\forall i \in \{1, 2, \cdots, n - 1\}, Q_{i+1} \neq P_{Q_i}\) 的方案数

思路

抽象题意:取编号 \(1\) ~ \(n\) 的点出来,每个点上有一个值,表示不能连出的边,计算所有经过且仅经过 \(1\) 次每个顶点的有向路径(哈密顿路径)的方案数

A.png

(图中橙色边表示不可连,蓝色边表示可连)

考虑容斥,枚举破坏 \(i\) 个条件(有 \(i\) 橙色边)

由于是全排列,所以必然有若干个圈(含自环),每个 \(k\) 元环可贡献 \(0\) ~ \(k-1\) 个橙色边(因为每个点只经过一次,不可能形成回路),则贡献的生成函数为

\[\begin{aligned} &1 + C_k^1 \cdot x + C_k^2 \cdot x^2 + \cdots + C_k^{k-1} \cdot x^{k-1} \\ =\ & (1 + x) ^k - x^k \end{aligned} \]

然后找出所有环,启发式合并这些多项式即可

复杂度 \(O(n\ log^2\ n)\)

代码

#include 
#define ll long long
#define ull unsigned long long
#define i64 long long
#define poly std::vector
// dont visit a[m] when a.size() <= m
// (a = fastpow(c,n-m+1,m+1)).resize(m+1);
// i64 res = a[m] - b[m];
// (b = fastpow(d,n-m+1,m+1)).resize(m+1);

constexpr int MOD = 998244353;

namespace Poly { // remember to resize
    const int N = (1 << 21), g = 3;
    inline int power(int x, int p) {
        int res = 1;
        for (; p; p >>= 1, x = (ll)x * x % MOD) 
            if (p & 1)
                res = (ll)res * x % MOD;
        return res;
    }
    inline int fix(const int x) { return x >= MOD ? x - MOD : x; }
    void dft(poly& A, int n) {
        static ull W[N << 1], *H[30], *las = W, mx = 0;
        for (; mx < n; mx++) {
            H[mx] = las;
            ull w = 1, wn = power(g, (MOD - 1) >> (mx + 1));
            for(int i=0;i<1<>= 1);
        }
        for (int k = 0, d = 1; k < n; k++, d <<= 1)
            for (int i = 0; i < (1 << n); i += (d << 1)) {
                ull *l = a + i, *r = a + i + d, *w = H[k], t;
                for (int j = 0; j < d; j++, l ++, r++) {
                    t = (*r) * (*w++) % MOD;
                    *r = *l + MOD - t, *l += t;
                }
            }
        for(int i=0;i<1<&a) { // return index
        std::priority_queue > H; // <-size, index>
        int n = a.size();
        for(int i=0;i=2) {
            int o1 = H.top().second; H.pop();
            int o2 = H.top().second; H.pop();
            poly res = mul(a[o1], a[o2]);
            a[o1].clear(); a[o2].clear();
            for(int i=0;i=MOD) x -= MOD;
    if(x<0) x += MOD;
}

int mul(int a,int b) {
    return 1ll * a * b % MOD;
}

int main(int argc, char const *argv[])
{
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr); std::cout.tie(nullptr);

    int n;
    std::cin >> n;
    std::vector p(n);
    for(int i=0;i> p[i];
        --p[i];
    }

    std::vector vis(n, false); // bool
    int circles = 0;
    std::vector cnt;
    for(int i=0;i ps(circles, poly());

    std::vector fac(n+1),ifac(n+1),inv(n+1);
    fac[0] = fac[1] = ifac[0] = ifac[1] = inv[0] = inv[1] = 1;
    for(int i=2;i<=n;++i) {
        fac[i] = mul(i, fac[i - 1]);
        inv[i] = mul(inv[MOD % i], MOD - MOD/i);
        ifac[i] = mul(inv[i], ifac[i - 1]);
    }

    auto C = [&](int n, int m) {
        return mul( fac[n], mul(ifac[m], ifac[n - m]) );
    };

    for(int i=0;i