【题解】CF700E Cool Slogans


自己做了个 *3300 出来,感觉非常的好,写一写题解。

题目传送门:luogu,CF。

单字符串的子串相关问题,考虑先套上一个 SAM。

考虑 \(s_p\) 对应的 SAM 上的节点 \(x\),因为 \(s_p\)\(s_{p+1}\) 的子串,所以最优情况下 \(s_{p+1}\) 一定为 \(s_p\) 的后缀,即 \(s_{p+1}\) 对应的节点一定在 \(x\) 的 fail 树子树中。

但是,SAM 中一个节点对应多个 endpos 集合相同的串,当产生一个新的 \(s_p\) 时应该取对应节点的一个什么长度的串呢?发现对于节点 \(x\) 对应的所有长度的串,因为它们 endpos 集合相同,所以每个长度的串可产生的 \(s_{p+1}\) 所对应的节点 \(y\) 的集合都一样(考虑 \(x\) 对应的 \(s_p\) 增长后如果 \(s_{p+1}\) 不能再对应某一个节点 \(y\),那说明增长前后的串 endpos 一定不同),而又要对于不同的 \(s_{p-1}\) 接纳进来的最多,所以长度直接取 \(maxlen_x\) 一定不会使情况变劣(为方便,下文中将 \(x\) 节点对应的串中长为 \(maxlen_x\) 的串成为 \(mls_x\))。

那么考虑一个树形 dp,\(dp_x\) 表示 \(x\) 子树中按上述构造方法能生成的最长的 \(s\) 序列,方程为 \(dp_x=max \{dp_y|y\in subtree_x \wedge show(mls_x,mls_y)\geq 2\}\)(其中 \(show(s1,s2)\) 表示 \(s1\)\(s2\) 中的出现次数)。

根据定义,对于一个 \(x\) 满足 \(show(mls_p,mls_x)\geq 2\)\(p\) 一定是它到根路径上深度浅的一段,设其中最深的点为 \(g_x\),可以发现 \(g_{fa_x}\) 一定为 \(g_x\) 的祖先节点,所以预处理出每个点的 \(g_x\) 就可以很方便的 dp 了。

预处理的一个麻烦之处就是判断要 \(show(mls_x,mls_y)\geq 2\) 是否为真,套路地线段树合并维护出每个点的 endpos 集合,对于 \(endpos_y\) 中的一个位置 \(p\),其所对应的串即为 \([p-maxlen_y+1,p]\) 这个区间。因为 \(x\)\(y\) 祖先,所以 \(p\) 一定也在 \(endpos_x\) 中,记 \(p\) 的前驱(在 \(endpos_x\) 中)为 \(q\),若 \(q\) 所对应的串 \([q-maxlen_x+1,q]\)\([p-maxlen_y+1,p]\) 包含,则说明 \(show(mls_x,mls_y)\geq 2\) 成立,否则不成立。因为所对应的串本质相同,所以 \(p\)\(endpos_y\) 中的任意元素均可,实现中取第一个比较方便。

代码见下,码风不很能看,勿喷。

#include
using namespace std;
#define re register

const int N=404040,M=N*80;
int n,tot=1,las=1,ch[N][26],ml[N],fa[N],fp[N],ton[N],tp[N],g[N],rt[N];
int cnt,sum[M],ls[M],rs[M],dp[N];

inline void ext(int c,int k){
	int x=++tot,p=las,q;
	las=x,ml[x]=ml[p]+1,fp[x]=k;
	while(!ch[p][c]&&p) ch[p][c]=x,p=fa[p];
	if(!p) return fa[x]=1,void();
	if(ml[q=ch[p][c]]==ml[p]+1) return fa[x]=q,void();
	ml[++tot]=ml[p]+1,fa[tot]=fa[q],fa[q]=fa[x]=tot;
	memcpy(ch[tot],ch[q],sizeof ch[tot]);
	while(ch[p][c]==q) ch[p][c]=tot,p=fa[p];
}

inline void topo(){
	for(re int i=1;i<=tot;++i) ++ton[ml[i]];
	for(re int i=tot-1;i;--i) ton[i]+=ton[i+1];
	for(re int i=1;i<=tot;++i) tp[ton[ml[i]]--]=i;
}

inline void upd(int &k,int l,int r,int x){
	sum[k=++cnt]=1;
	if(l==r) return ;
	int d=(l+r)>>1;
	if(x<=d) upd(ls[k],l,d,x);
	else upd(rs[k],d+1,r,x);
}

inline int mrg(int k1,int k2,int l,int r){
	if(!k1||!k2) return k1|k2;
	int k=++cnt,d=(l+r)>>1;sum[k]=sum[k1]+sum[k2];
	if(l!=r) ls[k]=mrg(ls[k1],ls[k2],l,d),rs[k]=mrg(rs[k1],rs[k2],d+1,r);
	return k;
}

inline void init(){
	for(re int i=1;i<=tot;++i) if(fp[i]!=0x3f3f3f3f) upd(rt[i],1,n,fp[i]);
	for(re int i=1;i<=tot;++i) g[i]=fa[i];
	for(re int i=1,x=tp[i],y=fa[x];i>1,res=pre(rs[k],d+1,r,x);
	return res?res:pre(ls[k],l,d,x);
}

inline bool check(int x,int y){return fp[y]-ml[y]<=pre(rt[x],1,n,fp[y]-1)-ml[x];}
inline void solve(){
	for(re int i=1,x=tp[i],y=fa[x];iml[g[x]]) g[y]=g[x];
	}
	for(re int i=1,x=tp[i];i