「HAOI2016」找相同字符


知识点: SA,线段树,广义 SAM

原题面 Loj Luogu

给定两字符串 \(S_1, S_2\),求出在两字符串中各取一个子串,使得这两个子串相同的方案数。
两方案不同当且仅当这两个子串中有一个位置不同。
\(1\le |S_1|, |S_2|\le 2\times 10^5\)

分析题意

线段树

考察对 \(\operatorname{lcp}\) 单调性的理解。


\(S_1\) 加个终止符,\(S_2\) 串扔到 \(S_1\) 后面,跑 SA。
显然,答案即后半段的后缀,与前半段的后缀的所有 \(\operatorname{lcp}\) 之和。


按字典序枚举后半段的后缀,设当前枚举到的后缀为 \(sa_i\)
仅考虑 字典序 \( 的 前半段的后缀 \(sa_j\ (j,其对 \(sa_i\) 的贡献为 \(\operatorname{lcp}(sa_i, sa_j)\)

\(\operatorname{lcp}\) 的单调性,当枚举到 第一个 \(>sa_i\)后半段的后缀 \(sa_k\ (k>i)\) 时,有 :\(\operatorname{lcp}(sa_{k}, sa_j)\le \operatorname{lcp}(sa_i,sa_j)\)

  1. \(\operatorname{lcp}(sa_{k}, sa_j)< \operatorname{lcp}(sa_i,sa_j)\),则 \(sa_j\)\(sa_k\) 的贡献应变为 \(\operatorname{lcp}(sa_k, sa_j) = \min\{\operatorname{lcp}(sa_i,sa_j), \min\limits_{l=i+1}^{k}{\{\operatorname{height}_l}\}\}\)

  2. 若存在 \(sa_l, l\in (i,k)\)前半段的后缀 时,作出贡献的元素增加。

考虑在枚举后缀的过程中,用权值线段树维护 字典序 \(前半段 的后缀 \(sa_j\ (j 的不同长度的 \(\operatorname{lcp}\) 的数量。
上述两操作,即为区间赋值 与 单点插入。


再按字典序倒序枚举后缀,计算字典序 \(>sa_i\) 的 前半段的后缀的贡献。
分析很屑,代码有详细注释。

复杂度 \(O(n\log n)\)
线段树写法比较无脑,也可以单调栈简单维护,复杂度也为\(O(n\log n)\) 级别。


广义 SAM

用两个字符串构造广义 SAM。
维护每个状态维护了几个串的 \(\operatorname{endpos}\)
当一个状态同时维护了两个串的 \(\operatorname{endpos}\),则该状态及其 parent 树上的祖先 所代表的串,均为公共子串。

\(size(u,0/1)\) 表示状态 \(u\) 维护了串 1/2 的 \(\operatorname{endpos}\) 的个数,有:

\[\sum size(i,0)\times size(i,1)\times (\operatorname{len}(i)-\operatorname{len}(\operatorname{link}(i)) \]

具体地,每插入串 \(i\) 的一个新字符,就对该字符对应的状态的 \(size(i) +1\)
在 parent 树上求子树 \(size\) 和,最后枚举状态更新答案。


爆零小技巧 1:在有返回值的函数中不写 return /se
爆零小技巧 2:边界玄学怎么办?判断正确性可靠对拍实现。


代码实现

广义 SAM

//知识点:SAM
/*
By:Luckyblock
试了试变量写法,挺清爽的。
*/
#include 
#include 
#include 
#include 
#define ll long long
const int kMaxn = 1e6 + 10;
const int kMaxm = 26;
//=============================================================
ll ans;
char S[kMaxn];
int size[kMaxn][2], id[kMaxn], cnt[kMaxn];
int num, node_num = 1, ch[kMaxn << 1][kMaxm], len[kMaxn <<1], link[kMaxn << 1];
//=============================================================
inline int read() {
  int f = 1, w = 0;
  char ch = getchar();
  for (; !isdigit(ch); ch = getchar())
    if (ch == '-') f = -1;
  for (; isdigit(ch); ch = getchar()) w = (w << 3) + (w << 1) + (ch ^ '0');
  return f * w;
}
int Insert(int c_, int last_) {
  if (ch[last_][c_]) {
    int p = last_, q = ch[p][c_];
    if (len[p] + 1 == len[q]) return q;
    int newq = ++ node_num;
    memcpy(ch[newq], ch[q], sizeof(ch[q])); 
    len[newq] = len[p] + 1; 
    link[newq] = link[q];
    link[q] = newq; 
    for (; p && ch[p][c_] == q; p = link[p]) ch[p][c_] = newq;
    return newq;
  }
  int p = last_, now = ++ node_num;
  len[now] = len[p] + 1;
  for (; p && ! ch[p][c_]; p = link[p]) ch[p][c_] = now;
  if (! p) {link[now] = 1; return now;} 
  int q = ch[p][c_];
  if (len[q] == len[p] + 1) {link[now] = q; return now;}
  int newq = ++ node_num;
  memcpy(ch[newq], ch[q], sizeof(ch[q])); 
  link[newq] = link[q], len[newq] = len[p] + 1; 
  link[q] = link[now] = newq;
  for (; p && ch[p][c_] == q; p = link[p]) ch[p][c_] = newq;
  return now;
}
//=============================================================
int main() {
  for (; num <= 1; ++ num) {
    scanf("%s", S + 1);
    int n = strlen(S + 1), last = 1;
    for (int i = 1; i <= n; ++ i) {
      last = Insert(S[i] - 'a', last);
      size[last][num] = 1;
    }
  }
  for (int i = 1; i <= node_num; ++ i) cnt[len[i]] ++;
  for (int i = 1; i <= node_num; ++ i) cnt[i] += cnt[i - 1];
  for (int i = 1; i <= node_num; ++ i) id[cnt[len[i]] --] = i;
  for (int i = node_num; i >= 2; -- i) {
    int now = id[i];
    size[link[now]][0] += size[now][0];
    size[link[now]][1] += size[now][1];
  }
  for (int i = 2; i <= node_num; ++ i) {
    ans += 1ll * size[i][0] * size[i][1] * (len[i] - len[link[i]]);
  }
  printf("%lld\n", ans);
  return 0; 
}

傻逼线段树

//知识点:SA
/*
By:Luckyblock 
*/
#include 
#include 
#include 
#include 
#define ll long long
#define lson (now_<<1)
#define rson (now_<<1|1)
const int kMaxn = 4e5 + 10;
//=============================================================
char S[kMaxn];
int n1, n, m, sa[kMaxn], rk[kMaxn << 1], oldrk[kMaxn << 1], height[kMaxn];
int id[kMaxn], cnt[kMaxn], rkid[kMaxn];
ll ans, size[kMaxn << 2], sum[kMaxn << 2]; //size 维护数量,sum 维护 lcp 之和。
bool tag[kMaxn << 2];
//=============================================================
inline int read() {
  int f = 1, w = 0; char ch = getchar();
  for (; !isdigit(ch); ch = getchar()) if (ch == '-') f = -1;
  for (; isdigit(ch); ch = getchar()) w = (w << 3) + (w << 1) + (ch ^ '0');
  return f * w;
}
bool cmp(int x, int y, int w) { //判断两个子串是否相等。
  return oldrk[x] == oldrk[y] && 
         oldrk[x + w] == oldrk[y + w]; 
}
void GetHeight() {
  for (int i = 1, k = 0; i <= n; ++ i) {
    if (rk[i] == 1) k = 0;
    else {
      if (k > 0) k --;
      int j = sa[rk[i] - 1];
      while (i + k <= n && j + k <= n && 
             S[i + k] == S[j + k]) {
        ++ k;
      }
    }
    height[rk[i]] = k;
  }
}
void SuffixSort() {
  m = std :: max(n, 300);
  for (int i = 1; i <= n; ++ i) ++ cnt[rk[i] = S[i]];
  for (int i = 1; i <= m; ++ i) cnt[i] += cnt[i - 1];
  for (int i = n; i >= 1; -- i) sa[cnt[rk[i]] --] = i;
  for (int p, w = 1; w < n; w <<= 1) {
    p = 0;
    for (int i = n; i > n - w; -- i) id[++ p] = i;
    for (int i = 1; i <= n; ++ i) {
      if (sa[i] > w) id[++ p] = sa[i] - w;
    }
    memset(cnt, 0, sizeof (cnt));
    for (int i = 1; i <= n; ++ i) ++ cnt[(rkid[i] = rk[id[i]])];
    for (int i = 1; i <= m; ++ i) cnt[i] += cnt[i - 1];
    for (int i = n; i >= 1; -- i) sa[cnt[rkid[i]] --] = id[i];
    std ::swap(rk, oldrk);
    m = 0;
    for (int i = 1; i <= n; ++ i) {
      m += (cmp(sa[i], sa[i - 1], w) ^ 1);
      rk[sa[i]] = m;
    }
  }
  GetHeight();
}
void Build(int now_, int L_, int R_) {
  size[now_] = sum[now_] = 0ll;
  tag[now_] = false;
  if (L_ == R_) return ;
  int mid = (L_ + R_) >> 1;
  Build(lson, L_, mid), Build(rson, mid + 1, R_);
}
void Pushdown(int now_) {
  tag[lson] = tag[rson] = true;
  size[lson] = size[rson] = 0;
  sum[lson] = sum[rson] = 0;
  tag[now_] = false;
}
void Pushup(int now_) {
  size[now_] = size[lson] + size[rson];
  sum[now_] = sum[lson] + sum[rson];
}
ll Delete(int now_, int L_, int R_, int ql_, int qr_) {
  if (ql_ <= L_ && R_ <= qr_) {
    ll ret = size[now_];
    tag[now_] = true;
    size[now_] = sum[now_] = 0;
    return ret;
  }
  if(tag[now_]) Pushdown(now_);
  int mid = (L_ + R_) >> 1;
  ll ret = 0ll;
  if (ql_ <= mid) ret += Delete(lson, L_, mid, ql_, qr_);
  if (qr_ > mid) ret += Delete(rson, mid + 1, R_, ql_, qr_);
  Pushup(now_);
  return ret;
}
void Insert(int now_, int L_, int R_, int pos_, ll num) {
  if (! num) return ;
  if (L_ == R_) {
    size[now_] += num;
    sum[now_] += 1ll * num * (L_ - 1ll); //注意减去偏移量。
    return ;
  }
  if (tag[now_]) Pushdown(now_);
  int mid = (L_ + R_) >> 1;
  if (pos_ <= mid) Insert(lson, L_, mid, pos_, num);
  else Insert(rson, mid + 1, R_, pos_, num);
  Pushup(now_);
}
//=============================================================
int main() {
  scanf("%s", S + 1); n1 = strlen(S + 1);
  S[n1 + 1] = 'z' + 1;
  scanf("%s", S + n1 + 2); n = strlen(S + 1);
  SuffixSort();

  //正序枚举所有后缀,计算字典序 >sa_i 的 前半段的后缀的贡献。
  //当枚举到一个 后半段的后缀,仅用于更新 min(lcp)。
  //枚举到一个 前半段的后缀,用于更新 min(lcp),且需新插入一个后缀。
  //由于 lcp 可能为 0,线段树维护的区间加了偏移量 1。
  for (int i = 2; i <= n; ++ i) {
    //计算 lcp > height(i) 的 前半段后缀的数量,并将他们删除。
    ll num = Delete(1, 1, n + 1, height[i] + 2, n + 1); 
    Insert(1, 1, n + 1, height[i] + 1, num + (sa[i - 1] <= n1)); //插入被删除的后缀 与 新后缀。注意边界。
    if (sa[i] > n1 + 1) ans += sum[1]; //若枚举到一个 后半段后缀,计算贡献。 注意边界。
  }
  Build(1, 1, n); //清空线段树
  //倒序枚举所有后缀,计算字典序 >sa_i 的 前半段的后缀的贡献。
  for (int i = n; i >= 2; -- i) {
    ll num = Delete(1, 1, n + 1, height[i] + 2, n + 1);
    Insert(1, 1, n + 1, height[i] + 1, num + (sa[i] <= n1)); //注意边界
    if (sa[i - 1] > n1 + 1) ans += sum[1]; //注意边界
  }
  printf("%lld", ans);
  return 0;
}