CF666E Forensic Examination——SAM+线段树合并+倍增


RemoteJudge

题目大意

给你一个串\(S\)以及一个字符串数组\(T[1...m]\)，有\(q\)次询问，每次问\(S\)的子串\(S[p_l...p_r]\)在\(T[l...r]\)中的哪个串里出现次数最多，并输出出现次数。
如有多解输出最靠前的那一个。

思路

第一次见到在\(parent\ tree\)上做线段树合并的题，感觉好妙。
先对\(T\)建一个广义后缀自动机，考虑对\(SAM\)上的每一个结点建一棵线段树，值域为\([1,m]\)，维护出现次数最多的串的编号和次数。又因为\(endpos\)集合（好像也叫\(right\)集合）有这么一个性质：一个结点的\(endpos\)集合等于其在\(parent\ tree\)上所有子结点的\(endpos\)集合的并（非克隆结点还要再加上它自己对应的那个位置），所以我们在建树时只需要上一个线段树合并即可。
上面的那个思路貌似是个套路?
然后来处理询问。显然我们只需要在\(S[p_l...p_r]\)对应的结点的线段树上查区间\([l,r]\)的最大值就行了，但如果每次都直接拿\(S[p_l...p_r]\)在\(SAM\)上匹配，复杂度绝壁不对QwQ。于是我们考虑先把整个\(S\)在\(SAM\)上匹配一遍，需要查哪个子串时再通过跳\(suflink\)来找。具体来说，对于\(S\)的一个前缀\(S[1...j]\)，如果它最后匹配到了结点\(u\)，匹配的长度为\(len\)，而我们要查的子串是\(S[i...j]\)：若\(len<j-i+1\)，说明该子串在所有\(T\)中都没有出现，直接输出\(l\)和\(0\)；否则从\(u\)开始跳\(suflink\)，直到一个\(maxlen\)大于等于\(j-i+1\)且深度最小的结点，记其为\(v\)，要查的就是\(v\)那棵线段树在\([l,r]\)上的答案。
最后发现跳\(suflink\)的过程可以用倍增来优化,然后就没了
吐槽1.为什么我写离线的就会\(WA\),在线的就过了
吐槽2.下午三点多写完，然后\(CF\)咕到了六点多，然后交了一发，\(WA\)了，我...

// Standard headers used below (the original header names were lost).
#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <map>
#include <utility>
#include <vector>

using namespace std;

// Contest shorthand macros. `pii`, `mii` and `vi` must carry their template
// arguments: a bare `#define vi vector` makes `vi G[...]` below ill-formed.
#define ull unsigned long long
#define pii pair<int, int>
#define uint unsigned int
#define mii map<int, int>
#define lbd lower_bound
#define ubd upper_bound
#define INF 0x3f3f3f3f
#define IINF 0x3f3f3f3f3f3f3f3fLL
#define vi vector<int>
#define ll long long
#define mp make_pair
#define pb push_back
#define re register
#define il inline

// Capacity limits. S holds MAXS characters and T holds MAXT characters;
// SAM tables hold 2*MAXT+5 states because extend() creates at most two states
// (cur and a clone) per inserted character. M and Q are not referenced below.
#define MAXS 500000
#define M 50000
#define Q 500000
#define MAXT 100000
// Binary-lifting levels: f[u][k] for k = 0..LIM.
#define LIM 16

// S[1..n]: the text whose substrings are queried; T[1..]: buffer for the
// T string currently being read.
char S[MAXS+5], T[MAXT+5];
// n = |S|, m = number of T strings, q = number of queries.
int n, m, q;
// Generalized SAM: nxt[c][u] = transition of state u on letter c,
// maxlen[u] = length of the longest string in u, link[u] = suffix link,
// nid1 = states allocated (root is 1), lst = last state of the string being
// inserted. in[] is not referenced in this listing.
int nxt[26][2*MAXT+5], maxlen[2*MAXT+5], link[2*MAXT+5], in[2*MAXT+5], nid1, lst;
// Segment-tree pool: nid2 = nodes allocated (0 is the null node),
// root[u] = tree of SAM state u over T indices [1, m],
// ch[0][o] / ch[1][o] = left / right child of node o.
int nid2, root[2*MAXT+5], ch[2][160*MAXT+5];
// Children lists of the parent (suffix-link) tree.
vi G[2*MAXT+5];
// f[u][k] = 2^k-th ancestor of u in the parent tree (0 above the root).
int f[2*MAXT+5][LIM+1];
// For each prefix S[1..i]: tar[i] = SAM state reached, ml[i] = matched length.
int tar[MAXS+5], ml[MAXS+5];

// Segment-tree payload: the best candidate string within a range.
// `w` is its occurrence count and `pos` its 1-based index among T[1..m].
// nodes[] is the payload pool; nodes[0] is never written, so it stays {0, 0}
// and acts as the value of an absent child.
struct Data {
  int w, pos;
  // Pick the better of two candidates: more occurrences wins, and on equal
  // counts the smaller index wins (otherwise rhs is returned).
  friend Data operator + (Data lhs, Data rhs) {
    const bool keepLeft = lhs.w > rhs.w || (lhs.w == rhs.w && lhs.pos < rhs.pos);
    return keepLeft ? lhs : rhs;
  }
  // "Worse than": fewer occurrences, or equal occurrences with a larger index.
  // Not referenced by the code shown here.
  bool operator < (const Data &rhs) const {
    if (w != rhs.w) return w < rhs.w;
    return pos > rhs.pos;
  }
}nodes[160*MAXT+5];

// Reset both pools. The SAM keeps only its root, state 1, which is also the
// insertion cursor. The segment-tree pool becomes empty (node 0 is the null node).
void init() {
  nid2 = 0;
  lst = 1;
  nid1 = lst;
}

void pushup(int o) {
  nodes[o] = nodes[ch[0][o]]+nodes[ch[1][o]];
}

// Record one more occurrence of string index x in the tree rooted at u,
// which covers [l, r]. Missing nodes are allocated on the way down; u is
// passed by reference so a freshly allocated root is written back to the caller.
void add(int &u, int l, int r, int x) {
  if (!u) u = ++nid2;
  if (l != r) {
    const int mid = (l + r) >> 1;
    if (x > mid) add(ch[1][u], mid + 1, r, x);
    else add(ch[0][u], l, mid, x);
    pushup(u);
    return;
  }
  // Leaf for index l: bump its count and tag it with its own index.
  ++nodes[u].w;
  nodes[u].pos = l;
}

// Merge trees x and y over [l, r] without modifying either input. Wherever
// only one side has a subtree, that subtree is shared as-is. Where both do, a
// fresh node is allocated whose leaf counts are the sums of the two.
// Returns the root of the merged tree (0 if both inputs are empty).
int merge(int x, int y, int l, int r) {
  if (x == 0) return y;
  if (y == 0) return x;
  const int fresh = ++nid2;
  if (l == r) {
    // Both leaves stand for index l, so x's pos is as good as y's.
    nodes[fresh].w = nodes[x].w + nodes[y].w;
    nodes[fresh].pos = nodes[x].pos;
    return fresh;
  }
  const int mid = (l + r) >> 1;
  const int leftChild = merge(ch[0][x], ch[0][y], l, mid);
  ch[0][fresh] = leftChild;
  ch[1][fresh] = merge(ch[1][x], ch[1][y], mid + 1, r);
  pushup(fresh);
  return fresh;
}

// Best (count, index) pair among string indices in [L, R], searching the tree
// rooted at o that covers [l, r]. An absent subtree contributes {0, 0}.
Data query(int o, int l, int r, int L, int R) {
  if (o == 0) return Data{0, 0};
  if (l >= L && r <= R) return nodes[o];
  const int mid = (l + r) >> 1;
  Data best{0, 0};
  if (mid >= L) best = best + query(ch[0][o], l, mid, L, R);
  if (mid < R) best = best + query(ch[1][o], mid + 1, r, L, R);
  return best;
}

// Append letter c (0..25) of string number `id` to the suffix automaton.
// insert() resets lst to the root before each string, so repeated calls build
// one generalized SAM over all T strings.
// If lst already has a c-transition on entry, the while loop does not run and
// the new state `cur` receives no incoming transition. `cur` still gets a
// suffix link, so its count for `id` reaches its ancestors when build() merges
// the parent tree.
void extend(int c, int id) {
  int cur = ++nid1;
  maxlen[cur] = maxlen[lst]+1;
  // One occurrence of string `id` ends at this new state.
  add(root[cur], 1, m, id);
  // Walk suffix links from lst, pointing missing c-transitions at cur.
  while(lst && !nxt[c][lst]) nxt[c][lst] = cur, lst = link[lst];
  if(!lst) link[cur] = 1;
  else {
    // NOTE: this local `q` shadows the global query count `q`.
    int p = lst, q = nxt[c][lst];
    if(maxlen[q] == maxlen[p]+1) link[cur] = q;
    else {
      // q also holds strings longer than maxlen[p]+1: split off a clone that
      // keeps q's transitions but has maxlen[p]+1, and hang q and cur under it.
      int clone = ++nid1;
      maxlen[clone] = maxlen[p]+1;
      link[clone] = link[q], link[q] = link[cur] = clone;
      for(int i = 0; i < 26; ++i) nxt[i][clone] = nxt[i][q];
      // Redirect the c-transitions that pointed at q along p's suffix path.
      while(p && nxt[c][p] == q) nxt[c][p] = clone, p = link[p];
    }
  }
  lst = cur;
}

// Feed the string currently stored in T[1..] into the automaton as string
// number `id`. Restarting from the root keeps it from extending the
// previously inserted string.
void insert(int id) {
  lst = 1;
  for (const char *s = T + 1; *s; ++s) extend(*s - 'a', id);
}

// Depth-first pass over the parent (suffix-link) tree starting at u.
// Fills u's binary-lifting row f[u][0..LIM] and folds each child's segment
// tree into root[u], so u ends up with the counts of its whole subtree.
void build(int u, int fa) {
  f[u][0] = fa;
  for (int k = 1; k <= LIM; ++k) {
    const int half = f[u][k-1];
    f[u][k] = f[half][k-1];
  }
  for (int child : G[u]) {
    build(child, u);
    root[u] = merge(root[u], root[child], 1, m);
  }
}

// Run S once through the automaton. For every prefix S[1..i], tar[i] is the
// state reached and ml[i] is the current match length: the longest suffix of
// S[1..i] the automaton accepts (0 if none).
void pre() {
  n = strlen(S+1);
  int state = 1, matched = 0;
  for (int i = 1; i <= n; ++i) {
    const int c = S[i] - 'a';
    if (!nxt[c][state]) {
      // Shorten the match along suffix links until c can be appended.
      while (state && !nxt[c][state]) state = link[state];
      if (state) {
        matched = maxlen[state] + 1;
        state = nxt[c][state];
      } else {
        state = 1;
        matched = 0;
      }
    } else {
      state = nxt[c][state];
      ++matched;
    }
    tar[i] = state;
    ml[i] = matched;
  }
}

int main() {
  scanf("%s%d", S+1, &m);
  init();
  for(int i = 1; i <= m; ++i) scanf("%s", T+1), insert(i);
  for(int i = 2; i <= nid1; ++i) G[link[i]].pb(i);
  build(1, 0);
  pre();
  scanf("%d", &q);
  for(int i = 1, l, r, pl, pr, L; i <= q; ++i) {
    scanf("%d%d%d%d", &l, &r, &pl, &pr);
    L = pr-pl+1;
    if(L > ml[pr]) printf("%d 0\n", l);
    else {
      int u = tar[pr];
      for(int k = LIM; ~k; --k) if(maxlen[f[u][k]] >= L) u = f[u][k];
      Data ret = query(root[u], 1, m, l, r);
      if(ret.w == 0) ret.pos = l;
      printf("%d %d\n", ret.pos, ret.w);
    }
  }
  return 0;
}