Educational Codeforces Round 118 (Rated for Div. 2) - F. Tree Coloring 题解



title: Codeforces-Edu118(Div.2)F. Tree Coloring
date: 2021-12-12 23:17:43
tags: [codeforces,div2,cpp,problem F,fft,divide and conquer,merge]

题意

给定一棵以节点 \(1\) 为根、共 \(n\) 个节点的树,用 \(1 \sim n\) 的一个排列给节点染色(各节点颜色 \(c_k\) 互不相同)。要求对每个非根节点 \(k\)(其父节点为 \(p_k\))都满足 \(c_k \neq c_{p_k} - 1\),求合法的染色方案数,答案对 \(998\,244\,353\) 取模。

思路

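容斥:称子节点 \(k\) "违反"其父边,当且仅当 \(c_k = c_{p_k} - 1\)。枚举强制违反的边集,边集大小为 \(i\) 时容斥系数为 \((-1)^i\)。由于颜色互不相同,一个节点 \(v\) 至多只能有一个儿子的颜色恰为 \(c_v - 1\)。所以对每个节点 \(v\),要么不选它连向儿子的边,要么从它的 \(d_v\) 个儿子(即出度)中恰好选一条,于是节点 \(v\) 的生成函数为

\[g_v(x) = 1 + d_v \cdot x \]

选定 \(i\) 条被违反的边后,这些边把节点串成 \(n - i\) 条颜色连续递减的链,把这些链整体排列即得 \((n-i)!\) 种染色,因此

\[\text{ans} = \sum_{i=0}^{n} (-1)^i \,(n-i)!\,[x^i] \prod_{v} g_v(x)\]

\(n\) 个一次多项式的乘积用分治 NTT(或启发式合并)计算即可,复杂度 \(O(n \log^2 n)\)。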

代码

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <utility>
#include <vector>
// Short integer aliases used throughout this solution. Scoped, type-checked
// `using` aliases replace the original object-like macros.
using ll = long long;
using ull = unsigned long long;
using i64 = long long;
// A polynomial is its coefficient vector, lowest degree first
// (e.g. poly{1, d} is 1 + d*x).
using poly = std::vector<int>;

// Usage notes for a `fastpow` helper that is not part of this file:
// don't access a[m] when a.size() <= m
// (a = fastpow(c,n-m+1,m+1)).resize(m+1);
// i64 res = a[m] - b[m];
// (b = fastpow(d,n-m+1,m+1)).resize(m+1);

constexpr int MOD = 998244353;

// Polynomial arithmetic over Z/MOD using the number-theoretic transform (NTT).
// MOD = 119 * 2^23 + 1, so power-of-two transform lengths up to 2^23 exist.
namespace Poly {
    // N: retained for compatibility; not used by the implementation below.
    // g: primitive root modulo MOD, used to build the transform's roots of unity.
    const int N = (1 << 21), g = 3;

    /// Returns x^p mod MOD by binary exponentiation.
    /// @param x base, expected in [0, MOD)
    /// @param p non-negative exponent
    inline int power(int x, int p) {
        int res = 1;
        for (; p; p >>= 1, x = static_cast<int>(1LL * x * x % MOD))
            if (p & 1)
                res = static_cast<int>(1LL * res * x % MOD);
        return res;
    }

    /// Maps x in [0, 2*MOD) to its representative in [0, MOD).
    inline int fix(const int x) { return x >= MOD ? x - MOD : x; }

    /// In-place NTT of `a`; a.size() must be a power of two and every entry in [0, MOD).
    /// With invert == true the inverse transform (including the 1/n scaling) is applied.
    inline void ntt(std::vector<int>& a, bool invert) {
        const int n = static_cast<int>(a.size());
        // Bit-reversal permutation so the iterative butterflies work in place.
        for (int i = 1, j = 0; i < n; ++i) {
            int bit = n >> 1;
            for (; j & bit; bit >>= 1)
                j ^= bit;
            j ^= bit;
            if (i < j)
                std::swap(a[i], a[j]);
        }
        for (int len = 2; len <= n; len <<= 1) {
            int w_len = power(g, (MOD - 1) / len);
            if (invert)
                w_len = power(w_len, MOD - 2);  // inverse root of unity
            const int half = len >> 1;
            for (int i = 0; i < n; i += len) {
                int w = 1;
                for (int j = 0; j < half; ++j) {
                    const int u = a[i + j];
                    const int v = static_cast<int>(1LL * a[i + j + half] * w % MOD);
                    // u + v and u - v + MOD both stay below 2*MOD < INT_MAX.
                    a[i + j] = fix(u + v);
                    a[i + j + half] = fix(u - v + MOD);
                    w = static_cast<int>(1LL * w * w_len % MOD);
                }
            }
        }
        if (invert) {
            const int inv_n = power(n, MOD - 2);
            for (int& x : a)
                x = static_cast<int>(1LL * x * inv_n % MOD);
        }
    }

    /// Returns the product of polynomials a and b modulo MOD.
    /// Coefficients are lowest degree first and expected in [0, MOD).
    /// The result has a.size() + b.size() - 1 coefficients, or is empty if either input is.
    inline std::vector<int> mul(std::vector<int> a, std::vector<int> b) {
        if (a.empty() || b.empty())
            return {};
        const std::size_t need = a.size() + b.size() - 1;
        // Schoolbook multiplication is cheaper than three transforms for tiny
        // operands (e.g. the linear leaves of a divide-and-conquer product).
        if (std::min(a.size(), b.size()) <= 32) {
            std::vector<int> res(need, 0);
            for (std::size_t i = 0; i < a.size(); ++i)
                for (std::size_t j = 0; j < b.size(); ++j)
                    res[i + j] = static_cast<int>((res[i + j] + 1LL * a[i] * b[j]) % MOD);
            return res;
        }
        std::size_t len = 1;
        while (len < need)
            len <<= 1;
        a.resize(len);
        b.resize(len);
        ntt(a, false);
        ntt(b, false);
        for (std::size_t i = 0; i < len; ++i)
            a[i] = static_cast<int>(1LL * a[i] * b[i] % MOD);
        ntt(a, true);
        a.resize(need);
        return a;
    }
}

// Lets callers at file scope (e.g. main) use `mul` unqualified.
using Poly::mul;

/// Reduces x from [-MOD, 2*MOD) into [0, MOD) in place.
inline void norm(int& x) {
    if (x >= MOD) x -= MOD;
    if (x < 0) x += MOD;
}

int main(int argc, char const *argv[])
{
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr); std::cout.tie(nullptr);

    int n;
    std::cin >> n;
    std::vector > g(n, std::vector());
    for(int i=0;i> u >> v;
        --u; --v;
        g[u].push_back(v);
        g[v].push_back(u);
    }

    auto dnc = [&](auto dnc,int l,int r) {
        if(r - l == 1) {
            return (poly) {1, (int) g[l].size() - (l != 0)};
        }

        int mid = l + r >> 1;
        return mul(dnc(dnc,l,mid), dnc(dnc,mid,r));
    };

    int res = 0;

    poly ans = dnc(dnc,0,n);
    ans.resize(n+1);

    std::vector fac(n+1);
    fac[0] = fac[1] = 1;
    for(int i=2;i<=n;++i) {
        fac[i] = 1ll * fac[i-1] * i % MOD;
    }

    for(int i=0;i<=n;++i) {
        int thiz = 1ll * fac[n - i] * ans[i] % MOD;
        norm(
            res += (i&1 ? MOD - thiz : thiz)
        );
    }
    
    std::cout << res;
    
    return 0;
}