Luogu P7279 光棱碎片

- 首先将区间形式的限制差分为两个前缀限制之差,于是只需对给定的 \(k\) 统计满足 \((a_{r_1}\oplus a_{r_2})+(r_1-l_1+1)\le k\) 的方案数。

- 将 \(\texttt{SAM}\) 建出来后,对每个本质不同子串的 \(\operatorname{endpos}\) 集合分别考虑。设点 \(x_1,x_2\) 分别为原序列中位置 \(r_1,r_2\) 在 \(\texttt{parent tree}\) 上对应的结点,设 \(y=\operatorname{lca}(x_1,x_2)\),那么点对 \((x_1,x_2)\) 的贡献为 \(\sum\limits_{i=1}^{\operatorname{len}_y}[(a_{r_1}\oplus a_{r_2})+i\le k]\),其中 \(i\) 枚举公共子串的长度 \(r_1-l_1+1\)。

- 注意到我们要统计所有的 \(\operatorname{endpos}\) 点对,可以考虑使用 \(\texttt{dsu on tree}\) 优化。于是需要维护一个数据结构,支持以下操作:

  1. 向当前容器 \(S\) 内加入一个数 \(x\)
  2. 查询 \(\sum\limits_{y\in S}\max\{0,\min\{k-(x\oplus y),d\}\}\),其中 \(k,d\) 为两个定值。

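这个查询直接维护并不方便。由于 \(\max\{0,\cdot\}\) 只保留 \(x\oplus y\le k\) 的项,我们记 \(c_k=\sum\limits_{y\in S}[x\oplus y\le k],g_k=\sum\limits_{y\in S}[x\oplus y\le k](x\oplus y)\)(下标表示任意阈值)。满足 \(x\oplus y\le k-d\) 的项取值都是 \(d\),满足 \(k-d<x\oplus y\le k\) 的项取值为 \(k-(x\oplus y)\),那么就有 \(A=\sum\limits_{x\oplus y\le k}\min\{k-(x\oplus y),d\}=d\cdot c_{k-d}+k\cdot (c_{k}-c_{k-d})-(g_{k}-g_{k-d})\)。于是我们只需考虑如何求出 \(c_k,g_k\)。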

- 建出 \(\texttt{01trie}\)。查询 \(c_k\) 是 01-trie 的基本操作,沿阈值的二进制位向下走即可;而查询 \(g_k\) 时,只需要在 \(\texttt{trie}\) 的每个结点上拆位维护子树内每一位为 \(1\) 的数的个数,就能按位累加出子树内所有 \(x\oplus y\) 之和。

- 总时间复杂度为 \(\mathcal O(n\log n\log ^2V)\),空间复杂度为 \(\mathcal O(n\log ^2V)\)。由于 \(\texttt{dsu on tree}\) 和 \(\texttt{01trie}\) 的常数都很小,可以通过本题。

参考代码

#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <utility>
using namespace std;
static constexpr int mod = 998244353;
// Arithmetic modulo `mod`. Operands are expected to already lie in [0, mod);
// each result is then reduced back into that range.
inline int add(int x, int y) {
  const int s = x + (y - mod);
  return s < 0 ? s + mod : s;
}
inline int sub(int x, int y) {
  const int d = x - y;
  return d < 0 ? d + mod : d;
}
inline int mul(int x, int y) {
  return static_cast<int>(static_cast<int64_t>(x) * y % mod);
}
// In-place variants of the helpers above.
inline void add_eq(int &x, int y) { x = add(x, y); }
inline void sub_eq(int &x, int y) { x = sub(x, y); }
inline void mul_eq(int &x, int y) { x = mul(x, y); }
// Shared array bound for the SAM states, parent-tree edges and input arrays.
// NOTE(review): a SAM over n characters can have up to 2n - 1 states, so this
// presumably relies on n <= 1e5 — confirm against the problem constraints.
// MaxS is the alphabet size (main feeds extend() with str[i] - 'a').
static constexpr int Maxn = 2e5 + 5, MaxS = 26;
// n: string length; en/head: edge count and adjacency heads of the parent tree;
// dn: DFS counter used by sack_init; ans: running answer modulo `mod`.
int n, en, head[Maxn], dn, ans;
// [wl, wr]: the allowed value range read in main; w[i]: the value at position i.
int wl, wr, w[Maxn];
// Input string, stored 1-indexed (main reads it into str + 1).
char str[Maxn];
// Parent-tree edges in forward-star form; add_edge(u, v) appends the edge u -> v.
struct Edge { int to, nxt; } e[Maxn];
// Appends the directed edge u -> v to the forward-star adjacency list.
void add_edge(int u, int v) {
  ++en;
  e[en].to = v;
  e[en].nxt = head[u];
  head[u] = en;
}
// Suffix-automaton state: transitions ch (0 means "no transition"), suffix link,
// and len, the length of the longest string in the state. State 0 is the root;
// main sets its link to -1.
struct state { int ch[MaxS], link, len; } tr[Maxn];
// last: state of the whole current prefix; sn: index of the newest state.
// edp[len] = the state created for the prefix of that length (written by extend,
// not read elsewhere in this file); iedp[state] = that prefix length, which stays 0
// for the root and for cloned states.
int last, sn, edp[Maxn], iedp[Maxn];
// Appends character c (a 0-based letter index) to the suffix automaton and records
// the new prefix state in edp/iedp.
void extend(int c) {
  int p = last;
  const int cur = ++sn;
  last = cur;
  tr[cur].len = tr[p].len + 1;
  edp[tr[cur].len] = cur;
  iedp[cur] = tr[cur].len;
  // Walk the suffix links, adding the missing c-transitions to cur.
  while (p != -1 && tr[p].ch[c] == 0) {
    tr[p].ch[c] = cur;
    p = tr[p].link;
  }
  // Fell off the root: cur's link keeps its zero-initialised value, the root 0.
  if (p == -1) return;
  const int q = tr[p].ch[c];
  if (tr[q].len == tr[p].len + 1) {
    tr[cur].link = q;
    return;
  }
  // Split q: the clone takes q's transitions and the shorter length.
  const int clone = ++sn;
  tr[clone].len = tr[p].len + 1;
  memcpy(tr[clone].ch, tr[q].ch, sizeof tr[clone].ch);
  while (p != -1 && tr[p].ch[c] == q) {
    tr[p].ch[c] = clone;
    p = tr[p].link;
  }
  tr[clone].link = tr[q].link;
  tr[q].link = clone;
  tr[cur].link = clone;
}
namespace trie {
  static constexpr int LOG = 17;
  struct node {
    int ch[2], c, s[LOG];
    node() = default;
  } tr[Maxn * LOG * 2];
  int tn = 1;
  inline int newnode(void) {
    return tr[++tn] = node(), tn;
  } // trie::newnode
  void insert(int w) {
    int p = 1; tr[p].c++;
    for (int k = 0; k < LOG; ++k)
      tr[p].s[k] += (w >> k & 1);
    for (int i = LOG - 1; i >= 0; --i) {
      int dir = w >> i & 1;
      if (!tr[p].ch[dir])
        tr[p].ch[dir] = newnode();
      p = tr[p].ch[dir]; tr[p].c++;
      for (int k = 0; k < LOG; ++k)
        tr[p].s[k] += (w >> k & 1);
    }
  } // trie::insert
  pair ask(int w, int r) {
    if (r < 0) return {0, 0};
    int p = 1, c = 0, s = 0;
    for (int i = LOG - 1; i >= 0 && p; --i) {
      int dir = ((w ^ r) >> i & 1) ^ 1;
      if (r >> i & 1) {
        c += tr[tr[p].ch[dir]].c;
        for (int k = 0; k < LOG; ++k) {
          int cs = (w >> k & 1)
            ? tr[tr[p].ch[dir]].c - tr[tr[p].ch[dir]].s[k]
            : tr[tr[p].ch[dir]].s[k];
          add_eq(s, ((int64_t)cs << k) % mod);
        }
      }
      p = tr[p].ch[dir ^ 1];
    }
    c += tr[p].c;
    for (int k = 0; k < LOG; ++k) {
      int cs = (w >> k & 1) ? tr[p].c - tr[p].s[k] : tr[p].s[k];
      add_eq(s, ((int64_t)cs << k) % mod);
    }
    return {c, s};
  } // trie::ask
} // namespace trie
// Computes, modulo `mod`, the sum over stored y of max(0, min(r - (w xor y), len)):
// the number of lengths i in [1, len] with (w xor y) + i <= r, summed over y.
// Evaluated as len*c(r-len) + r*(c(r) - c(r-len)) - (g(r) - g(r-len)), where c and g
// are the count and xor-sum returned by trie::ask.
inline int ask(int w, int r, int len) {
  const auto lo = trie::ask(w, r - len);
  const auto hi = trie::ask(w, r);
  // Values with xor <= r - len: every one of the len lengths fits.
  const int full = mul(lo.first, len);
  // Values with r - len < xor <= r: contribute r - xor each.
  const int partial = sub(mul(hi.first - lo.first, r), sub(hi.second, lo.second));
  return add(full, partial);
}
// Contribution of w against the current trie with value range [wl, wr]:
// the prefix count at wr minus the prefix count at wl - 1.
inline int query(int w, int len) {
  const int upto_hi = ask(w, wr, len);
  const int upto_lo = ask(w, wl - 1, len);
  return sub(upto_hi, upto_lo);
}
// Parent-tree data filled by sack_init: subtree size, heavy child (-1 for a leaf),
// depth (not read elsewhere in this file), DFS entry index, and its inverse map.
int sz[Maxn], hson[Maxn], dep[Maxn], dfn[Maxn], idfn[Maxn];
// Recursively computes depth, subtree size, heavy child (-1 for a leaf) and the
// DFS order (dfn / idfn) of the parent tree rooted at u.
void sack_init(int u, int depth) {
  dep[u] = depth;
  sz[u] = 1;
  hson[u] = -1;
  dfn[u] = ++dn;
  idfn[dn] = u;
  for (int eid = head[u]; eid != 0; eid = e[eid].nxt) {
    const int child = e[eid].to;
    sack_init(child, depth + 1);
    sz[u] += sz[child];
    const bool heavier = hson[u] == -1 || sz[child] > sz[hson[u]];
    if (heavier) hson[u] = child;
  }
}
void sack(int u, bool keep) {
  for (int i = head[u], v; i; i = e[i].nxt)
    if ((v = e[i].to) != hson[u]) sack(v, false);
  if (hson[u] != -1) sack(hson[u], true);
  if (iedp[u] != 0) {
    add_eq(ans, query(w[iedp[u]], tr[u].len));
    trie::insert(w[iedp[u]]);
  }
  for (int i = head[u], v; i; i = e[i].nxt)
    if ((v = e[i].to) != hson[u]) {
      for (int i = dfn[v], x; i < dfn[v] + sz[v]; ++i)
        if (iedp[x = idfn[i]] != 0) add_eq(ans, query(w[iedp[x]], tr[u].len));
      for (int i = dfn[v], x; i < dfn[v] + sz[v]; ++i)
        if (iedp[x = idfn[i]] != 0) trie::insert(w[iedp[x]]);
    }
  if (!keep) trie::tr[trie::tn = 1] = trie::node();
} // sack
// Reads n, the string, the values w[1..n] and the range [wl, wr]. Builds the SAM
// and its parent tree, runs dsu on tree from the root, and prints the answer
// modulo `mod`.
int main(void) {
  scanf("%d%s", &n, str + 1);
  // State 0 is the SAM root; a link of -1 marks "no parent".
  sn = 0;
  last = 0;
  tr[0].link = -1;
  for (int pos = 1; pos <= n; ++pos) extend(str[pos] - 'a');
  // Each non-root state hangs under its suffix link in the parent tree.
  for (int st = 1; st <= sn; ++st) add_edge(tr[st].link, st);
  for (int pos = 1; pos <= n; ++pos) scanf("%d", w + pos);
  scanf("%d%d", &wl, &wr);
  dn = 0;
  sack_init(0, 0);
  sack(0, false);
  printf("%d\n", ans);
  return EXIT_SUCCESS;
}