洛谷 P2178 [NOI2015] 品酒大会


链接

P2178


题意

给出一个长为 \(n(n\le3\times 10^5)\) 的字符串 \(S\),和 \(S\) 每个位置的权值 \(a_i(|a_i|\le 1\times 10^9)\),需要统计:

对于长度 \(x\in[0,n-1]\),统计选择任意两个 \(S\) 的后缀其 LCP 长度 \(\ge x\) 的选择的方案数;并统计每种 \(x\) 的方案中选取的两个位置权值之积的最大值。


分析

因为有后缀和LCP,首先把 \(S\) 的后缀数组跑出来。

我们发现此时的 \(ht\) 数组把排序后的后缀分成了几段,在大小为 \(x\)\(ht\) 以内,一个连通块内的所有后缀的 LCP 长度都大于 \(x\),这启示我们需要从上往下拆分或从下往上合并。
拆分很难做,所以我们考虑合并,那么维护并查集,信息包括并查集大小,最大次大权值,最小次小权值(因为存在负权值)。然后考虑合并顺序,我们发现从下往上合并时,\(ht\) 是单调不升的,也就是说按 \(ht\) 从大到小依次进行合并。两边存在并查集就合并,不存在就独自作为一个并查集。另外,\(ht[1]\) 没有意义,直接跳过就好了。合并并查集的时候注意一下信息的变化以及当前答案的变化。


算法

先做 SA,然后按 \(ht\) 从大到小排序后依次合并并查集,维护信息即可。


代码

#include
using namespace std;
#define int long long
#define in read()
inline int read(){
	int p=0,f=1;
	char c=getchar();
	while(!isdigit(c)){if(c=='-')f=-1;c=getchar();}
	while(isdigit(c)){p=p*10+c-'0';c=getchar();}
	return p*f;
}
const int N=3e5+5;
string S;
int n,m,rk[N<<1],sa[N],ht[N];
struct llmmkk{
	int fi,se,ref;
}st[N<<1],tmp[N<<1];
int cnt[N];
inline void jsort(){
	for(int i=0;i<=n;i++)cnt[i]=0;
	for(int i=1;i<=n;i++)cnt[st[i].se]++;
	for(int i=1;i<=n;i++)cnt[i]+=cnt[i-1];
	for(int i=n;i>0;i--)tmp[cnt[st[i].se]]=st[i],cnt[st[i].se]--;	
	for(int i=0;i<=n;i++)cnt[i]=0;
	for(int i=1;i<=n;i++)cnt[tmp[i].fi]++;
	for(int i=1;i<=n;i++)cnt[i]+=cnt[i-1];
	for(int i=n;i>0;i--)st[cnt[tmp[i].fi]]=tmp[i],cnt[tmp[i].fi]--;
}
int ston[257];
inline void SA(){
	for(int i=0;ib.ht;}
}ts[N];
int fa[N],sum[N],max1[N],max2[N],min1[N],min2[N];
inline int getf(int x){return fa[x]==x?x:fa[x]=getf(fa[x]);}
int ans1[N],ans2[N],ans,anst=-0x7fffffffffffffff,last,vis[N],timess;
inline void updatemin(int x,int a){
	if(x<=min1[a])min2[a]=min1[a],min1[a]=x;
	else if(x=max1[a])max2[a]=max1[a],max1[a]=x;
	else if(x>max2[a])max2[a]=x;
}
inline void merge(int x,int y){
	int f1=getf(x),f2=getf(y);
	if(f1==f2)return ;
	ans+=sum[f1]+sum[f2]+sum[f1]*sum[f2]+1,
	fa[f2]=f1,sum[f1]+=sum[f2]+1;
	updatemax(max1[f2],f1),updatemax(max2[f2],f1);
	updatemin(min1[f2],f1),updatemin(min2[f2],f1);
	anst=max(anst,max1[f1]*max2[f1]);
	anst=max(anst,min1[f1]*min2[f1]);
}
signed main(){
	cin>>n>>S;SA();
	for(int i=1;i<=n;i++)
	cin>>w[i],ts[i].ht=ht[i],ts[i].o=i;
	sort(ts+1,ts+1+n);
	for(int i=1;i<=n;i++){
		int t=ts[i].ht,to=ts[i].o,f1,tt;
		if(to==1)continue;
		if(t!=last){ans1[timess]=ans,ans2[timess]=anst,timess++,last=t;}
		if(vis[to-1]&&vis[to+1])merge(to-1,to+1);
		else{
			if(!vis[to-1]&&!vis[to+1]){
				fa[to]=to,vis[to]=sum[to]=1,ans++,
				min2[to]=max1[to]=max(w[sa[to]],w[sa[to-1]]),
				min1[to]=max2[to]=min(w[sa[to]],w[sa[to-1]]);
				anst=max(anst,w[sa[to]]*w[sa[to-1]]);
			}
			else{
				if(vis[to-1])f1=getf(to-1),tt=w[sa[to]];
				else f1=getf(to+1),tt=w[sa[to-1]];			
				updatemax(tt,f1),updatemin(tt,f1);			
				sum[f1]++,ans+=sum[f1],vis[to]=1,fa[to]=f1;
				anst=max(anst,max1[f1]*max2[f1]);
				anst=max(anst,min1[f1]*min2[f1]);	
			}
		}				
	}
	int f1=getf(2);		
	anst=max(anst,max1[f1]*max2[f1]);
	anst=max(anst,min1[f1]*min2[f1]);
	ans1[timess]=ans,ans2[timess]=anst;
	for(int i=timess;i>0;i--)cout<