Solution -「LOJ #6538」烷基计数 加强版 加强版


\(\mathscr{Description}\)

Link: LOJ #6538.

求含 \(n\) 个结点、无标号有根、每个结点儿子数量不超过 \(3\) 的树的数量,答案模 \(998244353\)。

\(n\le10^5\)。

\(\mathscr{Solution}\)

感觉 Burnside 只被用在了 corner 的地方,或许就是一道 GF 上手训练叭。

设 \([x^k]f(x)\) 表示 \(n=k\) 时问题的答案(就喜欢小写 \(f\) qwq),钦定 \([x^0]f(x)=1\)(空树)。转移即求三个可空的子树在三阶置换群 \(S_3\) 下的等价类数目:恒等置换的不动点贡献 \(f^3(x)\),\(3\) 个对换各贡献 \(f(x)f(x^2)\),\(2\) 个 \(3\)-轮换各贡献 \(f(x^3)\)。Burnside 一发得到

\[f(x)=1+\frac{1}{6}x(f^3(x)+3f(x)f(x^2)+2f(x^3)). \]

巧妙 trick:牛迭(牛顿迭代)从长度 \(n\) 倍增到 \(2n\) 时已经算出了 \([x^{0..n-1}]f(x)\),而 \(f(x^2)\) 和 \(f(x^3)\) 模 \(x^{2n}\) 只依赖这些系数,所以它们可以当做常多项式。具体地,我们想要解多项式方程 \(p(u,x)=f(x)-u=0\),那么 \(p(u,x)\) 在倍增时“等价”为

\[p(u,x)=\frac{1}{6}xu^3+\frac{1}{2}xuf(x^2)+\frac{1}{3}xf(x^3)+1-u. \]

因而

\[u_{2n} = u_n-\frac{p(u_n,x)}{\frac{\text d}{\text du}p(u_n,x)}, \]

其中

\[\frac{\text d}{\text du}p(u,x)=\frac{1}{2}xu^2+\frac{1}{2}xf(x^2)-1. \]

所以多项式求逆 + 牛迭即可。注意 \(\operatorname{dft}_n(x)=\langle w_n^{0..n-1}\rangle\),即乘 \(x\) 可以直接在点值上完成,节约一下 DFT 次数。复杂度是 \(\mathcal O(n\log n)\) 的。

\(\mathscr{Code}\)

/*+Rainybunny+*/

#include <algorithm>
#include <cassert>
#include <cstdio>
#include <utility>

// Inclusive ascending loop: i goes from l up to r; r is evaluated once.
#define rep(i, l, r) for (int i = (l), i##_rep_end = (r); i <= i##_rep_end; ++i)
// Inclusive descending loop: i goes from r down to l; l is evaluated once.
#define per(i, r, l) for (int i = (r), i##_per_end = (l); i >= i##_per_end; --i)

// Problem-size bound, NTT buffer length (2^18), and the NTT-friendly prime.
constexpr int MAXN = 1e5, MAXL = 1 << 18, MOD = 998244353;
// Modular inverses of 6, 2 and 3 modulo MOD.
constexpr int INV6 = 166374059, INV2 = 499122177, INV3 = 332748118;

// Product of two residues modulo MOD (widened to 64 bits before reducing).
inline int mul(const int u, const int v) {
    return static_cast<long long>(u) * v % MOD;
}
// u^2 mod MOD.
inline int sqr(const int u) { return mul(u, u); }
// u^3 mod MOD.
inline int cub(const int u) { return mul(mul(u, u), u); }
// (u - v) mod MOD for residues u, v in [0, MOD).
inline int sub(int u, const int v) {
    const int diff = u - v;
    return diff < 0 ? diff + MOD : diff;
}
// (u + v) mod MOD for residues u, v in [0, MOD).
inline int add(int u, const int v) {
    const int sum = u + v;
    return sum >= MOD ? sum - MOD : sum;
}
// In-place u := (u - v) mod MOD.
inline void subeq(int& u, const int v) { u = sub(u, v); }
// In-place u := (u + v) mod MOD.
inline void addeq(int& u, const int v) { u = add(u, v); }
// Binary exponentiation: u^v mod MOD for v >= 0.
inline int mpow(int u, int v) {
    int acc = 1;
    while (v) {
        if (v & 1) acc = mul(acc, u);
        u = mul(u, u);
        v >>= 1;
    }
    return acc;
}

namespace PolyOper {

// Generator used to build roots of unity; assumed to be a primitive root of
// MOD = 998244353 = 119 * 2^23 + 1.
const int MG = 3;
// omega[i][j] = w^j where w = MG^((MOD - 1) / 2^i), i.e. the powers of a
// 2^i-th root of unity; filled for i in [1, 18] and j in [0, 2^i).
int omega[19][MAXL];

/// Precomputes the root-of-unity tables read by ntt(). Must be called once
/// before any transform.
inline void init() {
    rep (i, 1, 18) {
        int* oi = omega[i]; oi[0] = 1, oi[1] = mpow(MG, MOD - 1 >> i);
        rep (j, 2, (1 << i) - 1) oi[j] = mul(oi[j - 1], oi[1]);
    }
}

/// In-place cyclic number-theoretic transform of length n.
/// @param n    transform length; a power of two with n <= MAXL.
/// @param u    array of n residues, overwritten with the result.
/// @param type 1 for the forward transform (u[k] <- sum_j u[j] * w_n^{jk}),
///             -1 for the inverse transform (including the 1/n factor).
inline void ntt(const int n, int* u, const int type) {
    // Bit-reversal table, rebuilt only when the length changes.
    static int rev[MAXL], lasn = -1;
    if (lasn != n) {
        lasn = n;
        rep (i, 1, n - 1) rev[i] = rev[i >> 1] >> 1 | (i & 1) * n >> 1;
    }
    // std::swap instead of the chained XOR swap `a ^= b ^= a ^= b`, whose
    // unsequenced modifications are undefined behaviour before C++17.
    rep (i, 1, n - 1) if (i < rev[i]) std::swap(u[i], u[rev[i]]);
    // Iterative butterflies: at level i, blocks of size 2 * stp = 2^i are
    // merged using the 2^i-th roots stored in omega[i].
    for (int i = 1, stp = 1; stp < n; ++i, stp <<= 1) {
        int* oi = omega[i];
        for (int j = 0; j < n; j += stp << 1) {
            rep (k, j, j + stp - 1) {
                int x = u[k], y = mul(oi[k - j], u[k + stp]);
                u[k] = add(x, y), u[k + stp] = sub(x, y);
            }
        }
    }
    if (type == -1) {
        // Evaluating at w^{-k} is evaluating at w^{n-k}: reverse u[1..n-1].
        std::reverse(u + 1, u + n);
        // n divides MOD - 1, so n * ((MOD - 1) / n) == -1 (mod MOD); hence
        // MOD - (MOD - 1) / n is the inverse of n.
        int iv = MOD - (MOD - 1) / n;
        rep (i, 0, n - 1) u[i] = mul(u[i], iv);
    }
}

/// Power-series inverse: r = u^{-1} mod x^n via Newton doubling
/// r <- r * (2 - u * r).
/// @param n length; a power of two with 2 * n <= MAXL.
/// @param u input coefficients u[0..n-1]; u[0] must be nonzero.
/// @param r output, r[0..n-1] is written (may hold garbage on entry).
inline void pinv(const int n, const int* u, int* r) {
    if (n == 1) return assert(u[0]), r[0] = mpow(u[0], MOD - 2), void();
    pinv(n >> 1, u, r);
    static int tmp[2][MAXL];
    // tmp[0] = current inverse (n/2 terms), tmp[1] = u mod x^n, both padded
    // to 2n so the pointwise product below does not wrap around.
    rep (i, 0, (n >> 1) - 1) tmp[0][i] = r[i], tmp[1][i] = u[i];
    rep (i, n >> 1, n - 1) tmp[0][i] = 0, tmp[1][i] = u[i];
    rep (i, n, (n << 1) - 1) tmp[0][i] = tmp[1][i] = 0;
    ntt(n << 1, tmp[0], 1), ntt(n << 1, tmp[1], 1);
    rep (i, 0, (n << 1) - 1) {
        tmp[0][i] = mul(tmp[0][i], sub(2, mul(tmp[0][i], tmp[1][i])));
    }
    ntt(n << 1, tmp[0], -1);
    rep (i, 0, n - 1) r[i] = tmp[0][i];
}

/// Debug-only self-check of pinv, compiled only when RYBY is defined.
inline void _test_() {
#ifdef RYBY
    int tmp[] = { 1, 6, 3, 4, 9, 0, 0, 0 }, res[8];
    pinv(8, tmp, res);
    rep (i, 0, 7) printf("%d%c", res[i], i < 7 ? ' ' : '\n');
#endif
}

} using namespace PolyOper;

// n: requested tree size.
// f : coefficients of f(x) (temporarily its NTT image inside solve()).
// g : p(f, x); dg: dp/du evaluated at u = f; rg: power-series inverse of dg.
// sf: f(x^2) truncated; cf: f(x^3) truncated.
int n, f[MAXL], g[MAXL], dg[MAXL], rg[MAXL], sf[MAXL], cf[MAXL];

/// Newton doubling for f(x) = 1 + x/6 * (f(x)^3 + 3 f(x) f(x^2) + 2 f(x^3)).
/// On return f[0..len-1] holds [x^0..x^{len-1}] f(x) and f[len..2len-1] is 0.
/// len must be a power of two with 2 * len <= MAXL.
inline void solve(const int len) {
    if (len == 1) return void(f[0] = 1);
    solve(len >> 1);
    // Here f[0..len/2-1] is exact and f[len/2..2len-1] is zero.

    rep (i, 0, len - 1) g[i] = dg[i] = sf[i] = cf[i] = 0;
    // f(x^2) and f(x^3) mod x^len only use f[0..len/2-1], which is already
    // known, so they act as constant polynomials during this step.
    rep (i, 0, (len - 1) / 2) sf[i << 1] = f[i];
    rep (i, 0, (len - 1) / 3) cf[i * 3] = f[i];
    // Entries [len, 2len) of sf/cf were never written by shorter levels
    // (their transforms had length <= len), so they are still zero.
    ntt(len << 1, f, 1), ntt(len << 1, sf, 1), ntt(len << 1, cf, 1);

    // wn = log2(len) + 1: omega[wn][i] is the i-th 2len-th root of unity,
    // which is exactly the length-2len transform of the polynomial x.
    int wn = 32 - __builtin_clz(len);
    rep (i, 0, (len << 1) - 1) {
        int wi = omega[wn][i];
        // g = x * (f^3 / 6 + f * f(x^2) / 2 + f(x^3) / 3) + 1 - f = p(f, x).
        // Every product has degree < 2len, so the cyclic product is exact.
        g[i] = sub(mul(wi, add(mul(INV6, cub(f[i])),
          add(mul(INV2, mul(f[i], sf[i])), mul(INV3, cf[i])))), sub(f[i], 1));
        // dg = x * (f^2 + f(x^2)) / 2 - 1 = dp/du at u = f.
        dg[i] = sub(mul(mul(INV2, wi), add(sqr(f[i]), sf[i])), 1);
    }
    ntt(len << 1, g, -1), ntt(len << 1, dg, -1);
    // Keep both only modulo x^len.
    rep (i, len, (len << 1) - 1) g[i] = dg[i] = 0;

    // dg[0] == MOD - 1 (i.e. -1), so the inverse exists.
    pinv(len, dg, rg);
    ntt(len << 1, g, 1), ntt(len << 1, rg, 1);
    // f is still in the transformed domain from above: f <- f - g / dg.
    rep (i, 0, (len << 1) - 1) subeq(f[i], mul(g[i], rg[i]));
    ntt(len << 1, f, -1);
    // Truncate to len terms so the zero-tail invariant holds for the caller.
    rep (i, len, (len << 1) - 1) f[i] = 0;
}

int main() {
    scanf("%d", &n), init();
    // _test_();
    solve(1 << 32 - __builtin_clz(n));
    printf("%d\n", f[n]);
    return 0;
}

相关