CF1276F-Asterisk Substrings【SAM,线段树合并】


正题

题目链接:https://www.luogu.com.cn/problem/CF1276F


题目大意

给出一个长度为\(n\)的字符串\(S\),现在依次进行如下操作

  1. 取出\(S\)的一个子串\(T\)
  2. \(T\)中的一个字符替换成\(*\)号(也可以不替换)

求最后有多少种不同的\(T\)


解题思路

发现最终其实只有4种情况,\(T,T*,*T,T_1*T_2\)

前面三种很好记录,主要考虑最后一种。

对于\(T_1\)来说,同一个\(endpos\)等价类中的子串对应的\(T_2\)数量应该也是相同的。

那我们肯定是先建一个\(SAM\)这样就可以知道每个\(endpos\)等价类了。

那么考虑怎么统计\(T_2\)的数量,其实对应\(i\)在某个\(endpos\)里,那么我们就只需要考虑从\(i+2\)这些位置开始的不同串的数量。

也就是确定起始位置的子串,这提醒我们反着再建立一个\(SAM\),然后把所有\(i+2\)位置的点标记了,这些点和根在\(fail\)树上的虚树路径长度和就是我们要知道的答案。

那么怎么维护这个东西呢,我们在正着的\(SAM\)上每个点维护一个线段树,然后统计反着的\(SAM\)上的链并长度,这个显然是可以用线段树合并的。

时间复杂度:\(O(n\log^2 n)\)(如果肯写\(O(1)\)LCA的话可以做到\(O(n\log n)\)


code

#include
#include
#include
#define ll long long
using namespace std;
const ll N=2e5+10,M=N<<5;
struct SAM{
	ll cnt,last,len[N],pos[N],fa[N],ch[N][26];
	void Ins(ll c,ll id){
		ll p=last,np=last=++cnt;
		len[np]=len[p]+1;pos[id]=np;
		for(;p&&!ch[p][c];p=fa[p])ch[p][c]=np;
		if(!p)fa[np]=1;
		else{
			ll q=ch[p][c];
			if(len[p]+1==len[q])fa[np]=q;
			else{
				ll nq=++cnt;len[nq]=len[p]+1;
				memcpy(ch[nq],ch[q],sizeof(ch[nq]));
				fa[nq]=fa[q];fa[q]=fa[np]=nq;
				for(;p&&ch[p][c]==q;p=fa[p])ch[p][c]=nq;
			}
		}
		return;
	}
}Suf,Pre;
struct node{
	ll to,next;
}a[N];
ll n,m,tot,cnt,ls[N],siz[N],dep[N],rt[N];
ll dfn[N],rfn[N],son[N],top[N],ans=2;
char s[N];
void addl(ll x,ll y){
	a[++tot].to=y;
	a[tot].next=ls[x];
	ls[x]=tot;return;
}
void dfs1(ll x){
	siz[x]=1;
	for(ll i=ls[x];i;i=a[i].next){
		ll y=a[i].to;dep[y]=dep[x]+1;
		dfs1(y);siz[x]+=siz[y];
		if(siz[y]>siz[son[x]])son[x]=y;
	}
	return;
}
void dfs2(ll x){
	dfn[++cnt]=x;rfn[x]=cnt;
	if(son[x]){
		top[son[x]]=top[x];
		dfs2(son[x]);
	}
	for(ll i=ls[x];i;i=a[i].next){
		ll y=a[i].to;
		if(y==son[x])continue;
		top[y]=y;dfs2(y);
	}
	return;
}
ll LCA(ll x,ll y){
	while(top[x]!=top[y]){
		if(dep[top[x]]>1;
		if(pos<=mid)Change(ls[x],L,mid,pos);
		else Change(rs[x],mid+1,R,pos);
		Merge(x);return;
	}
	ll Merge(ll x,ll y,ll L,ll R){
		if(!x||!y)return x|y;
		ll mid=(L+R)>>1;
		ls[x]=Merge(ls[x],ls[y],L,mid);
		rs[x]=Merge(rs[x],rs[y],mid+1,R);
		Merge(x);return x;
	}
}T;
void solve(ll x){
	for(ll i=ls[x];i;i=a[i].next){
		ll y=a[i].to;
		solve(y);rt[x]=T.Merge(rt[x],rt[y],1,cnt);
	}
	ans+=T.w[rt[x]]*(Pre.len[x]-Pre.len[Pre.fa[x]]);
	return;
}
signed main()
{
	scanf("%s",s+1);n=strlen(s+1);
	Pre.last=Pre.cnt=Suf.last=Suf.cnt=1;
	for(ll i=n;i>1;i--)Suf.Ins(s[i]-'a',i);
	for(ll i=1;i