「AHOI2013」 差异


知识点: SA,线段树,单调栈

原题面 Loj Luogu


题意简述

给定一长度为 \(n\) 的字符串 \(S\),令 \(T_i\) 表示从第 \(i\) 个字符开始的后缀,求:

\[\sum_{1\le i

\(\operatorname{len}(a)\) 表示字符串 \(a\) 的长度,\(\operatorname{lcp}(a,b)\) 表示字符串 \(a,b\) 的最长公共前缀。

分析题意

SA

化下式子:

\[\begin{aligned} ans &= \sum_{1\le i

考虑如何快速求后一半,即所有 \(\operatorname{lcp}\) 之和。

发现有下列等价关系:

\[\sum_{1\le i

\(\operatorname{lcp}(a,b) = \operatorname{lcp}(b,a)\),枚举 \(sa\) 一定不会重也不会漏。

类似这题的套路:「HAOI2016」找相同字符,
考虑枚举 \(sa_j\),用权值线段树维护 \(sa_i (i 的不同长度的 \(\operatorname{lcp}(sa_i, sa_j)\) 的数量。

引理:\(\forall 1\le i < j\le n,\, \operatorname{lcp}(sa_i,sa_j) = \min\limits_{k=i+1}^j\{\operatorname{height_k}\}\)
模拟引理,当 \(j+1\) 时将权值线段树中所有 \(>\operatorname{height}_{j+1}\) 的元素删除,并添加相同个数个 元素 \(\operatorname{height}_{j+1}\)
添加一个 \(\operatorname{height}_{j+1}\),代表新增的 \(sa_j\) 的贡献。
贡献求和即可。

总复杂度 \(O(n\log n)\)


线段树太傻逼了,考虑单调栈。
发现有下列等价关系:

\[\sum_{1\le i

即求 \(\operatorname{height}\) 每个区间的区间最小值之和。
经典问题,考虑 \(\operatorname{height}\) 作为最小值的区间的最大 左/右端 点,可单调栈维护。
答案即 \(\sum\limits_{i=2}^{n}(i-l_i)\times (r_i-i)\times \operatorname{height}_i\)

注意区间长度不能为 1。


后缀树

考虑原始式子:

\[\sum_{1\le i

这玩意长得很树上差分。
对于 \(S\) 的后缀树,\(\operatorname{lcp}\) 即为后缀树的 \(\operatorname{lca}\)
上式等价于后缀树上所有后缀之间的距离。

对反串建 SAM,即得后缀树。
题目转化为:树上某一点是多少 表示后缀的节点 的 \(\operatorname{lca}\) 再乘上 \(dep\)
记录子树大小, DP 实现即可。


爆零小技巧:线段树不一定只开 4 倍空间,当 \(n\) 到达 \(5\times 10^5\) 级别一定要小心。


代码实现


SA + 单调栈

这写法挺神仙的,感觉要重学单调栈。

//
/*
By:Luckyblock
*/
#include 
#include 
#include 
#include 
#include 
#define ll long long
const int kMaxn = 5e5 + 10;
//=============================================================
char S[kMaxn];
int n, m, sa[kMaxn], rk[kMaxn << 1], oldrk[kMaxn << 1], height[kMaxn];
int cnt[kMaxn], id[kMaxn], rkid[kMaxn];
int top, st[kMaxn], l[kMaxn], r[kMaxn];
//=============================================================
inline int read() {
  int f = 1, w = 0; char ch = getchar();
  for (; !isdigit(ch); ch = getchar()) if (ch == '-') f = -1;
  for (; isdigit(ch); ch = getchar()) w = (w << 3) + (w << 1) + (ch ^ '0');
  return f * w;
}
void GetMax(int &fir, int sec) {
  if (sec > fir) fir = sec;
}
void GetMin(int &fir, int sec) {
  if (sec < fir) fir = sec;
}
int cmp(int x, int y, int w) {
  return oldrk[x] == oldrk[y] && 
         oldrk[x + w] == oldrk[y + w];
}
void GetHeight() {
  for (int i = 1, k = 0; i <= n; ++ i) {
    if (rk[i] == 1) k = 0;
    else {
      if (k > 0) k --;
      int j = sa[rk[i] - 1];
      while (i + k <= n && j + k <= n && 
             S[i + k] == S[j + k]) {
        ++ k;
      }
    }
    height[rk[i]] = k;
  }
}
void SuffixSort() {
  n = strlen(S + 1);
  m = 1010;
  for (int i = 1; i <= n; ++ i) cnt[rk[i] = S[i]] ++;
  for (int i = 1; i <= m; ++ i) cnt[i] += cnt[i - 1];
  for (int i = n; i; -- i) sa[cnt[rk[i]] --] = i;
  
  for (int p, w = 1; w < n; w <<= 1) {
    p = 0;
    for (int i = n; i > n - w; -- i) id[++ p] = i;
    for (int i = 1; i <= n; ++ i) {
      if (sa[i] > w) id[++ p] = sa[i] - w;
    }
    memset(cnt, 0, sizeof (cnt));
    for (int i = 1; i <= n; ++ i) cnt[rkid[i] = rk[id[i]]] ++;
    for (int i = 1; i <= m; ++ i) cnt[i] += cnt[i - 1];
    for (int i = n; i; -- i) sa[cnt[rkid[i]] --] = id[i]; 
    
    std :: swap(rk, oldrk);
    m = 0;
    for (int i = 1; i <= n; ++ i) {
      m += (cmp(sa[i], sa[i - 1], w) ^ 1);
      rk[sa[i]] = m;
    }
  }
  GetHeight();
}
//=============================================================
int main() {
  scanf("%s", S + 1);
  SuffixSort();
  ll ans = 1ll * ((n - 1ll) * n / 2ll) * (n + 1ll) ;
  st[(top = 1)] = 1;
	for (int i = 2; i <= n; ++ i) {
		while (top && height[st[top]] > height[i]) {
		  r[st[top]] = i;
		  top --;
    }
		l[i] = st[top];
		st[++ top] = i;
	} 
  while (top) r[st[top --]] = n + 1;
  for (int i = 2; i <= n; ++ i) {
    ans -= 2ll * (i - l[i]) * (r[i] - i) * height[i]; 
  }
  printf("%lld", ans);
  return 0;
}

SA + 线段树

//知识点:SA
/*
By:Luckyblock
*/
#include 
#include 
#include 
#include 
#include 
#define ll long long
#define lson (now_<<1)
#define rson (now_<<1|1)
const int kMaxn = 5e5 + 10;
//=============================================================
char S[kMaxn];
int n, m, sa[kMaxn], rk[kMaxn << 1], oldrk[kMaxn << 1], height[kMaxn];
int cnt[kMaxn], id[kMaxn], rkid[kMaxn];
ll size[kMaxn << 3], sum[kMaxn << 3];
bool tag[kMaxn << 3];
//=============================================================
inline int read() {
  int f = 1, w = 0; char ch = getchar();
  for (; !isdigit(ch); ch = getchar()) if (ch == '-') f = -1;
  for (; isdigit(ch); ch = getchar()) w = (w << 3) + (w << 1) + (ch ^ '0');
  return f * w;
}
void GetMax(int &fir, int sec) {
  if (sec > fir) fir = sec;
}
void GetMin(int &fir, int sec) {
  if (sec < fir) fir = sec;
}
int cmp(int x, int y, int w) {
  return oldrk[x] == oldrk[y] && 
         oldrk[x + w] == oldrk[y + w];
}
void GetHeight() {
  for (int i = 1, k = 0; i <= n; ++ i) {
    if (rk[i] == 1) k = 0;
    else {
      if (k > 0) k --;
      int j = sa[rk[i] - 1];
      while (i + k <= n && j + k <= n && 
             S[i + k] == S[j + k]) {
        ++ k;
      }
    }
    height[rk[i]] = k;
  }
}
void SuffixSort() {
  n = strlen(S + 1);
  m = 1010;
  for (int i = 1; i <= n; ++ i) cnt[rk[i] = S[i]] ++;
  for (int i = 1; i <= m; ++ i) cnt[i] += cnt[i - 1];
  for (int i = n; i; -- i) sa[cnt[rk[i]] --] = i;
  
  for (int p, w = 1; w < n; w <<= 1) {
    p = 0;
    for (int i = n; i > n - w; -- i) id[++ p] = i;
    for (int i = 1; i <= n; ++ i) {
      if (sa[i] > w) id[++ p] = sa[i] - w;
    }
    memset(cnt, 0, sizeof (cnt));
    for (int i = 1; i <= n; ++ i) cnt[rkid[i] = rk[id[i]]] ++;
    for (int i = 1; i <= m; ++ i) cnt[i] += cnt[i - 1];
    for (int i = n; i; -- i) sa[cnt[rkid[i]] --] = id[i]; 
    
    std :: swap(rk, oldrk);
    m = 0;
    for (int i = 1; i <= n; ++ i) {
      m += (cmp(sa[i], sa[i - 1], w) ^ 1);
      rk[sa[i]] = m;
    }
  }
  GetHeight();
}
void Pushdown(int now_) {
  tag[lson] = tag[rson] = true;
  size[lson] = size[rson] = 0;
  sum[lson] = sum[rson] = 0;
  tag[now_] = false;
}
void Pushup(int now_) {
  size[now_] = size[lson] + size[rson];
  sum[now_] = sum[lson] + sum[rson];
}
ll Delete(int now_, int L_, int R_, int ql_, int qr_) {
  if (ql_ <= L_ && R_ <= qr_) {
    ll ret = size[now_];
    tag[now_] = true;
    size[now_] = sum[now_] = 0ll;
    return ret;
  }
  if(tag[now_]) Pushdown(now_);
  int mid = (L_ + R_) >> 1;
  ll ret = 0ll;
  if (ql_ <= mid) ret += Delete(lson, L_, mid, ql_, qr_);
  if (qr_ > mid) ret += Delete(rson, mid + 1, R_, ql_, qr_);
  Pushup(now_);
  return ret;
}
void Insert(int now_, int L_, int R_, int pos_, ll num) {
  if (! num) return ;
  if (L_ == R_) {
    size[now_] += num;
    sum[now_] += 1ll * num * (L_ - 1ll);
    return ;
  }
  if (tag[now_]) Pushdown(now_);
  int mid = (L_ + R_) >> 1;
  if (pos_ <= mid) Insert(lson, L_, mid, pos_, num);
  else Insert(rson, mid + 1, R_, pos_, num);
  Pushup(now_);
}
//=============================================================
int main() {
  scanf("%s", S + 1);
  SuffixSort();
  ll ans = 1ll * ((n - 1ll) * n / 2ll) * (n + 1ll) ;
  for (int j = 2; j <= n; ++ j) {
    ll num = Delete(1, 1, n + 1, height[j] + 2, n + 1);
    Insert(1, 1, n + 1, height[j] + 1, num + 1);
    ans -= 2ll * sum[1];
  }
  printf("%lld", ans);
  return 0;
}

后缀树

咕咕咕,建议 Lg题解。