CodeChef Expected Repetitions


CodeChef Expected Repetitions

? 记 \(\operatorname{power}(S)\) 为字符串 \(S\) 所有子串的幂值和,显然答案即为 \(\dfrac{2\operatorname{power}(S)}{N(N+1)}\)。考虑如何计算 \(\operatorname{power}(S)\)

? 我们将满足子串 \(T=R+R+\ldots+R+P\) ( 其中 \(P\)\(R\) 的前缀 ) 的字符串 \(R\) 称作为子串 \(T\) 的周期性前缀。将所有满足上述关系的 \((T,R)\) 称作为一对贡献组。换句话说,一个贡献组 \((T,R)\) 满足 \(T\)\(S\) 的子串且 \(R\)\(T\) 的周期性前缀。那么 \(\operatorname{power}(S)\) 即为所有贡献组 \((T,R)\)\(R\) 的权值之和。

? 考虑 \(S\) 中两个相等的子串 \(S\;\![\;\!a, b\;\!]=S\;\![\;\!c, d\;\!]\) ( 不妨设 \(b ) ,那么就有 \(S\;\![\;\!a,c)\)\(S\;\![\;\!a,d\;\!]\) 的周期性前缀,即 \(\left(S\;\![\;\!a,d\;\!],S\;\![\;\!a,c)\right)\) 是一对贡献组。容易发现,对于不同的 \(a,b;c,d\) 产生的贡献组是不同的;而若贡献组 \((T,R)\) 满足 \(|R|<|T|\),则该贡献组必定能被一组 \(a,b;c,d\) 生成。因此我们的想法是将所有的贡献组 \((T,R)\) 拆成两部分去统计:\(|R|=|T|\),和 \(|R|<|T|\)

? 对于 \(|R|=|T|\) 的贡献组是 trivial 的。重点是如何处理那些 \(|R|<|T|\) 的贡献组。由上述知,每一个 \(|R|<|T|\) 的贡献组都可以对应出一组相等的子串 \(S\;\![\;\!a,b\;\!]=S\;\![\;\!c,d\;\!]\),而该贡献组的贡献为子串 \(S\;\![\;\!a,c)\) 的权值。因此这一部分的贡献可以转化为所有满足 \(S\;\![\;\!a,b\;\!]=S\;\![\;\!c,d\;\!]\) 的四元组 \((a,b;c,d)\) 的贡献,而该贡献即为 \(S\;\![\;\!a,c)\) 的权值。

? 考虑换一种方式计算上述的贡献值。设 \(T\)\(S\) 的本质不同子串,记 \(T\)\(S\) 中出现的位置集合为 \(p\),那么子串 \(T\) 的贡献是 \(\sum\limits_{i=1}^{|p|}\sum\limits_{j=i+1}^{|p|}w(S\;\![\;\!p_i,p_j))\)

? 这里我们使用后缀自动机去解决这个问题。将原串 \(S\) 翻转后,本质不同子串 \(T\) 的贡献为 \(\sum\limits_{i=1}^{|\operatorname{edp}(T)|}\sum\limits_{j=i+1}^{|\operatorname{edp}(T)|}w(S\;\!(\operatorname{edp}_i,\operatorname{edp}_j\;\!])\),其中 \(\operatorname{edp}(T)\)\(T\)\(S\)\(\operatorname{endpos}\) 集合。于是我们建立出反串的 SAM 之后,用 dsu on tree 维护每一个节点内的 \(\operatorname{endpos}\) 集合。那么现在要实现的操作即为:

  1. 向当前集合 \(K\) 内插入一个位置 \(x\)
  2. 查询 \(\sum\limits_{i=1}^{|K|}\sum\limits_{j=i+1}^{|K|}w(K_i,K_j\;\!]\)
  3. 清空集合 \(K\)

\(W_i=w\;\![\;\!1,i\;\!]\) 那么查询式可以写成 \(\sum\limits_{i=1}^{|K|}W_{K_i}(i-1)-\sum\limits_{i=1}^{|K|}W_{K_i}(k-i)\)。于是我们只需要用树状数组去实现这个容器就行了。总时间复杂度为 \(\mathcal O(n\log ^2n)\)

Bonus\(\hat{S}\) 为字符串 \(S\) 反转后的结果,例如 \(\hat{\texttt{abc}}=\texttt{cba}\)。试证明:\(\operatorname{power}(S)=\operatorname{power}(\hat{S})\)

参考代码
#include 
using namespace std;
static constexpr int mod = 998244353, inv2 = (mod + 1) / 2;
inline int add(int x, int y, int M = mod) { return (x += y) >= M ? x - M : x; }
inline int sub(int x, int y, int M = mod) { return (x -= y) < 0 ? x + M : x; }
inline int mul(int x, int y, int M = mod) { return (int64_t)x * y % M; }
inline int &add_eq(int &x, int y, int M = mod) { return x = add(x, y, M); }
inline int &sub_eq(int &x, int y, int M = mod) { return x = sub(x, y, M); }
inline int &mul_eq(int &x, int y, int M = mod) { return x = mul(x, y, M); }
inline int qpow(int x, int y, int M = mod)
{ int r = 1; for (; y; y >>= 1, mul_eq(x, x, M)) (y & 1) && (mul_eq(r, x, M)); return r; }
static constexpr int Maxn = 1e6 + 10, MaxS = 26;
int n, W[26], pW[Maxn], ans1, ans2;
char str[Maxn];
namespace sam {
  struct state { int ch[MaxS], link, len; } tr[Maxn];
  int last, tot, sam[Maxn], isam[Maxn];
  inline void initialize(void) {
    memset(&tr[0], 0, sizeof(state));
    last = tot = 0; tr[0].link = -1; isam[0] = 0;
  } // initialize
  inline int newnode(void) {
    memset(&tr[++tot], 0, sizeof(state));
    isam[tot] = 0; return tot;
  } // newnode
  void extend(int c) {
    int p = last, cur = last = newnode(), r;
    sam[tr[cur].len = tr[p].len + 1] = cur; isam[cur] = tr[cur].len;
    for (; ~p && !tr[p].ch[c]; p = tr[p].link) tr[p].ch[c] = cur;
    if (p == -1) return ; int q = tr[p].ch[c];
    if (tr[q].len == tr[p].len + 1) tr[cur].link = q;
    else {
      tr[r = newnode()].len = tr[p].len + 1;
      memcpy(tr[r].ch, tr[q].ch, MaxS << 2);
      for (; ~p && tr[p].ch[c] == q; p = tr[p].link) tr[p].ch[c] = r;
      tr[r].link = tr[q].link, tr[q].link = tr[cur].link = r;
    }
  } // extend
} // namespace sam
using namespace sam;
struct Edge { int to, nxt; } e[Maxn];
int head[Maxn], etot;
inline void graph_initialize(int N) {
  memset(head, -1, (N + 1) << 2); etot = 0;
} // graph_initialize
void add_edge(int u, int v) {
  e[etot] = (Edge){v, head[u]};
  head[u] = etot++;
} // add_edge
int sz[Maxn], son[Maxn];
int dfn[Maxn], idfn[Maxn], ed[Maxn], dfn_index;
void sack_init(int u, int fa) {
  sz[u] = 1; son[u] = -1;
  idfn[dfn[u] = ++dfn_index] = u;
  for (int i = head[u], v; ~i; i = e[i].nxt)
    if ((v = e[i].to) != fa) {
      sack_init(v, u), sz[u] += sz[v];
      if (son[u] == -1 || sz[v] > sz[son[u]]) son[u] = v;
    }
  ed[u] = dfn_index;
} // sack_init
int b1[Maxn], b2[Maxn];
inline void mdf(int *b, int x, int w) {
  for (; x <= n; x += x & -x) add_eq(b[x], w);
} // mdf
inline int ask(int *b, int x) {
  int w = 0;
  for (; x; x -= x & -x) add_eq(w, b[x]);
  return w;
} // ask
inline int qry(int *b, int l, int r) {
  return sub(ask(b, r), ask(b, l - 1));
} // qry
int curAns;
void upd(int x, int type) {
  int wR = sub(qry(b2, x, n), mul(qry(b1, x, n), pW[x]));
  int wL = sub(mul(qry(b1, 1, x - 1), pW[x]), qry(b2, 1, x - 1));
  (type == 1 ? add_eq : sub_eq)(curAns, add(wL, wR), mod);
  mdf(b1, x, type == 1 ? 1 : -1);
  mdf(b2, x, (type == 1 ? add : sub)(0, pW[x], mod));
} // upd
void sack(int u, int fa, bool keep) {
  for (int i = head[u], v; ~i; i = e[i].nxt)
    if ((v = e[i].to) != fa) if (v != son[u])
      sack(v, u, false);
  if (son[u] != -1) sack(son[u], u, true);
  for (int i = head[u], v; ~i; i = e[i].nxt)
    if ((v = e[i].to) != fa) if (v != son[u])
      for (int j = dfn[v]; j <= ed[v]; ++j)
        if (isam[idfn[j]] != 0) upd(isam[idfn[j]], 1);
  if (isam[u] != 0) upd(isam[u], 1);
  if (u != 0) add_eq(ans1, mul(curAns, tr[u].len - tr[tr[u].link].len));
  if (keep == false)
    for (int j = dfn[u]; j <= ed[u]; ++j)
      if (isam[idfn[j]] != 0) upd(isam[idfn[j]], -1);
} // sack
int main(void) {
  int tests; scanf("%d", &tests);
  while (tests--) {
    scanf("%s", str + 1); n = strlen(str + 1);
    for (int i = 0; i < 26; ++i) scanf("%d", &W[i]), W[i] %= mod;
    for (int i = 1; i <= n; ++i) pW[i] = add(pW[i - 1], W[str[i] - 'a']);
    sam::initialize();
    for (int i = 1; i <= n; ++i) sam::extend(str[i] - 'a');
    graph_initialize(tot);
    for (int i = 1; i <= tot; ++i) add_edge(tr[i].link, i);
    dfn_index = 0; sack_init(0, -1);
    ans1 = ans2 = 0; sack(0, -1, false);
    for (int i = 1, pS = 0; i <= n; ++i)
      add_eq(ans2, sub(mul(i, pW[i]), pS)), add_eq(pS, pW[i]);
    int num = add(ans1, ans2), denom = mul(mul(n, n + 1), inv2);
    printf("%d\n", mul(num, qpow(denom, mod - 2)));
  } exit(EXIT_SUCCESS);
} // main