P4199 万径人踪灭 题解


Description

Luogu传送门

Solution

好名字。

像我这种懒狗肯定是直接看简化版题意啊!

题意让我们求不连续的回文子序列个数,我们发现并不是很好求,所以转换一下思路,求:

回文子序列个数(可连续) - 回文子串个数

回文子串个数很容易求出来,就是 \(manacher\) 板子,下面我们来考虑如何求出回文子序列个数。

设点 \(id\) 为一个子序列的对称中心,若 \(S_{i - j} = S_{i + j}\) ,我们的答案就会多一种选择,假设对于对称中心 \(i\)\(f_i = x\) 种选择,那么:

  1. \(i\) 是一个字符,答案就是 \(2^{x + 1} - 1\)。包括对称中心 \(i\) 在内有 \(x + 1\) 对选或不选的字符,所以一共有 \(2^{x + 1}\) 种情况,减去空集的 1 种情况就是答案。
  2. \(i\) 在两个字符中间,答案就是 \(2^x - 1\),即不能包括对称中心。

通过上面的讨论可以发现这个对称中心在两个字符中间,所以我们按照 \(manacher\) 的做法把原字符串翻倍,这样一来这个中心就一定是整点了。

那么这个 \(f_i\) 如何求呢?

构造出生成函数:

\[F(x) = \sum\limits_{i = 0}^{|s| - 1}a_ix^i \]

然后令 \(F(x)\) 自己卷自己,即 \(F(x)^2\), 这一步用 NTT 或 FFT 都行。

不难发现,\(F(x)^2\)\(x^i\) 的系数就是在原字符串中关于 \(\frac{i}{2}\) 对称的字符对的个数。

又观察到输入的字符串中只有 ab ,而且这两个字符之间没有关联,所以分开考虑:

  1. 对于 a,若 \(S_i = a\),那么 \(A_i = 1\),反之 \(A_i = 0\)
  2. 对于 b 同理,若 \(S_i = b\),那么 \(B_i = 1\),反之 \(B_i = 0\)

得出 \(A\)\(B\) 之后我们就可以计算 \(f_i\) 啦。

\[f_i = A^2 + B^2 \]

最后统计一下答案就完啦。

Code

我这里用的 NTT 做的,需要两个模数,注意写的时候不要弄混了 QwQ

#include 
#define ll long long
#define cl const ll

using namespace std;

namespace IO{
    inline ll read(){
        ll x = 0;
        char ch = getchar();
        while(!isdigit(ch)) ch = getchar();
        while(isdigit(ch)) x = (x << 3) + (x << 1) + ch - '0', ch = getchar();
        return x;
    }

    template  inline void write(T x){
        if(x > 9) write(x / 10);
        putchar(x % 10 + '0');
    }
}
using namespace IO;

const ll N = 3e5 + 10;
const ll mod = 998244353, Mod = 1e9 + 7;
const ll G = 3, Gi = 332748118;
ll n, m, ans1, ans2;
ll a[N], b[N], c[N], d[N], rev[N];
int p[N];
char s[N], t[N];

namespace NTT{
    ll lim, len;

    inline ll qpow(ll a, ll b, ll mod){
        ll res = 1;
        while(b){
            if(b & 1) res = res * a % mod;
            a = a * a % mod, b >>= 1;
        }
        return res;
    }

    inline void get_rev(cl n){
        lim = 1, len = 0;
        while(lim < n) lim <<= 1, ++len;
        for(int i = 0; i < lim; ++i) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (len - 1));
    }

    inline void ntt(ll A[], cl lim, cl type){
        for(int i = 0; i < lim; ++i)
            if(i < rev[i]) swap(A[i], A[rev[i]]);
        for(int mid = 1; mid < lim; mid <<= 1){
            ll Wn = qpow(type == 1 ? G : Gi, (mod - 1) / (mid << 1), mod);
            for(int i = 0; i < lim; i += (mid << 1)){
                ll w = 1;
                for(int j = 0; j < mid; ++j, w = w * Wn % mod){
                    ll x = A[i + j], y = w * A[i + j + mid] % mod;
                    A[i + j] = (x + y) % mod;
                    A[i + j + mid] = (x - y + mod) % mod;
                }
            }
        }
        if(type == 1) return;
        ll inv = qpow(lim, mod - 2, mod);
        for(int i = 0; i < lim; ++i) A[i] = A[i] * inv % mod;
    }

    inline void Mul(cl n, cl m, ll a[], ll b[]){
        get_rev(n + m);
        ntt(a, lim, 1), ntt(b, lim, 1);
        for(int i = 0; i < lim; ++i) a[i] = a[i] * b[i] % mod;
        ntt(a, lim, -1);
    }
}
using namespace NTT;

inline void manacher(){
    n = strlen(t + 1);
    s[0] = '#', s[(n << 1) + 1] = '*';
    for(int i = 1; i <= n; ++i)
        s[(i << 1) - 1] = '*', s[i << 1] = t[i];
    m = (n << 1) + 1;
    int mx = 0, id = 0;
    for(int i = 1; i <= m; ++i){
        if(i < mx) p[i] = min(mx - i, p[(id << 1) - i]);
        else p[i] = 1;
        while(i - p[i] >= 1 && i + p[i] <= m && s[i - p[i]] == s[i + p[i]]) ++p[i];
        if(i + p[i] > mx) id = i, mx = i + p[i];
    }
    for(int i = 1; i <= (n << 1) + 1; ++i) ans1 = (ans1 + (p[i] >> 1)) % Mod;
}

inline void solve(){
    for(int i = 1; i <= n; ++i){
        if(t[i] == 'a') a[i - 1] = c[i - 1] = 1;
        else b[i - 1] = d[i - 1] = 1;
    }
    Mul(n, n, a, c), Mul(n, n, b, d);
    for(int i = 0; i <= (n << 1); ++i) c[i] = (a[i] + b[i]) % mod;
    for(int i = 0; i <= (n << 1) - 2; ++i){
        if(i & 1) ans2 = (ans2 + qpow(2, (c[i] >> 1), Mod) - 1 + Mod) % Mod;
        else ans2 = (ans2 + qpow(2, (c[i] >> 1) + 1, Mod) - 1 + Mod) % Mod;
    }
    write((ans2 - ans1 + Mod) % Mod), puts("");
}

int main(){
    scanf("%s", t + 1);
    manacher();
    solve();
    return 0;
}

\[\_EOF\_ \]

相关