【题解】Asterisk Substrings Codeforces 1276F 后缀自动机 树链的并


第一道独立解决的Div1F,嘿嘿,幸好没看题解


把串分为以下几类

不包含star的串

即原串的本质不同子串个数:建SAM后求 Σ(len[u] - len[pa[u]]) 即可,太简单,略

star在最前面的串(star后面至少还有一个字符):等价于 s[2..n] 的本质不同子串个数

star在最后面的串(star前面至少还有一个字符):等价于 s[1..n-1] 的本质不同子串个数

单独一个star

答案++

单独一个空串

答案++

star在中间的串

注意到,假设star的位置是pos,实际上相当于选择一个右端点为pos-1的非空串s1,再选择一个左端点为pos+1的非空串s2,问本质不同的pair(s1,s2)有多少个(star是唯一的特殊字符,所以不同的pair恰好对应不同的串 s1*s2)

也就是选两个原串的子串,并且这两个子串要满足上面那个条件,问方案数

对原串的前n-2个字符(即 s[1..n-2])建SAM,称为sam

对原串的后n-2个字符(即 s[3..n])倒过来再建SAM,称为rsam

注意到,SAM上每个点表示的本质不同子串数量是len[u] - len[pa[u]],其中len是点u所表示字符串的最长长度,pa是点u在后缀树上的父亲,记这个值为val[u]

也就是说,问题变成了:枚举sam里面的一个点u,枚举rsam里面的一个点v,如果v的end_pos集合(按原串下标记,即这些串在原串中的起始位置)中存在一个数字,等于u的end_pos集合里面的某个数字+2,那么ans += val[u] * val[v]

考虑sam里面的每个点u,假设u的end_pos集合是{a1, a2, a3, ..., ak},那么在rsam里面,有哪些点可以和u产生贡献?所有end_pos集合包含某一个ai+2的点都可以和u产生贡献,这在rsam的后缀树上,是k条从根出发的树链的并

在sam的后缀树上跑DSU on Tree,维护上述end_pos集合,并时刻维护集合中所有点在rsam后缀树上对应的树链的并的权值和(按dfs序排序后,减去相邻两点LCA处的重复部分)

总复杂度两个log,即 O(n log² n)

#include <cstdio>
#include <cstring>
#include <set>
#include <utility>

using namespace std;
using ll = long long;		// 64-bit type used for every answer count
constexpr int N = 200010;	// array capacity; presumably sized for |s| <= 1e5 (a SAM has < 2|s| states) — confirm
int _w;				// sink for scanf's return value (silences unused-result warnings)

// Suffix automaton over a 26-letter alphabet (callers pass str[i] - 'a').
// State 0 is the root (empty string); live states are 0..idx-1.
struct SAM {
	int ch[N][26];	// ch[u][c]: transition from state u on character c (0 = no transition)
	int len[N];	// len[u]: length of the longest string represented by state u
	int pa[N];	// pa[u]: suffix link of u, i.e. its parent in the suffix-link tree (-1 for the root)
	int idx;	// number of states allocated so far; the next new state gets this index
	
	// Clear every table and leave only the root state (index 0).
	void init() {
		memset(ch, 0, sizeof ch);
		memset(len, 0, sizeof len);
		memset(pa, 0, sizeof pa);
		idx = 1;
		pa[0] = -1;
	}
	// Extend the automaton by character c. p must be the state returned by the
	// previous call (or 0 for an empty automaton). Returns the state of the
	// whole extended string, to be passed as p on the next call.
	int append( int p, int c ) {
		int np = idx++;
		len[np] = len[p] + 1;
		// Walk up the suffix links of the old string, adding the missing transition on c.
		while( p != -1 && !ch[p][c] )
			ch[p][c] = np, p = pa[p];
		if( p == -1 ) pa[np] = 0;	// no suffix of the old string was already followed by c
		else {
			int q = ch[p][c];
			if( len[q] == len[p] + 1 ) pa[np] = q;	// q's longest string is exactly p's + c: link directly
			else {
				// Split q: the clone nq takes over the strings of length <= len[p] + 1.
				int nq = idx++;
				memcpy(ch[nq], ch[q], sizeof ch[nq]);
				len[nq] = len[p] + 1;
				pa[nq] = pa[q];
				pa[q] = pa[np] = nq;
				// Redirect the transitions on c that pointed at q along the suffix chain to nq.
				while( p != -1 && ch[p][c] == q )
					ch[p][c] = nq, p = pa[p];
			}
		}
		return np;
	}
};

int n;		// length of the input string
char str[N];	// input string, stored 1-indexed in str[1..n]
SAM sam;	// forward automaton; rebuilt (init) by each solve_* routine and by solve()
SAM rsam;	// automaton of the reversed suffix str[3..n], built by solve()

// Number of distinct non-empty substrings of str[from..to] (1-indexed,
// inclusive); 0 when from > to. Rebuilds the global `sam` as a side effect.
// Each SAM state u (other than the root) represents exactly
// len[u] - len[pa[u]] distinct substrings, so summing that over all states
// counts every distinct substring once.
static ll count_distinct_substrings( int from, int to ) {
	sam.init();
	int last = 0;
	for( int i = from; i <= to; ++i )
		last = sam.append(last, str[i] - 'a');
	ll total = 0;
	for( int i = 1; i < sam.idx; ++i )
		total += sam.len[i] - sam.len[sam.pa[i]];
	return total;
}

// Strings with no star: distinct substrings of the whole string str[1..n].
ll solve_origin() {
	return count_distinct_substrings(1, n);
}

// "*" followed by a non-empty part: distinct substrings of str[2..n].
ll solve_before() {
	return count_distinct_substrings(2, n);
}

// A non-empty part followed by "*": distinct substrings of str[1..n-1].
ll solve_after() {
	return count_distinct_substrings(1, n - 1);
}

// Array-based adjacency list: each vertex keeps the index of its most recently
// added outgoing edge, and every edge links to the previous one from the same vertex.
struct Graph {
	int head[N];	// head[u]: index of the last edge added out of u, -1 if none
	int nxt[N];	// nxt[e]: previous edge out of the same source vertex, -1 ends the list
	int to[N];	// to[e]: target vertex of edge e
	int eid;	// number of edges added so far
	// Drop all edges.
	void init() {
		memset(head, -1, sizeof head);
		eid = 0;
	}
	// Add the directed edge u -> v; iterating from head[u] visits it first.
	void link( int u, int v ) {
		const int e = eid++;
		to[e] = v;
		nxt[e] = head[u];
		head[u] = e;
	}
};
Graph g;	// suffix-link tree of `sam` (parent -> child edges)
Graph rg;	// suffix-link tree of `rsam` (parent -> child edges)

// Heavy-light decomposition of rsam's suffix-link tree (edges in `rg`, root 0),
// used for LCA queries and for mapping between vertices and DFS preorder.
namespace HLD {
	int dfn[N];	// preorder index of each vertex (1-based)
	int dfnc;	// preorder counter
	int top[N];	// head of the heavy chain containing the vertex
	int dep[N];	// depth, root = 1
	int pa[N];	// tree parent, -1 for the root
	int sz[N];	// subtree size
	int son[N];	// heavy child, -1 for a leaf
	int val[N];	// rsam.len of the vertex = weight of its root path (the per-edge len differences telescope)
	int rdfn[N];	// inverse of dfn: rdfn[dfn[u]] == u

	// First pass: parent, depth, weight, subtree size and heavy child.
	// Among equal-size children the first one visited stays heavy.
	void dfs1( int u, int fa, int d ) {
		pa[u] = fa;
		dep[u] = d;
		sz[u] = 1;
		val[u] = rsam.len[u];
		for( int e = rg.head[u]; e != -1; e = rg.nxt[e] ) {
			const int child = rg.to[e];
			dfs1(child, u, d + 1);
			sz[u] += sz[child];
			const bool heavier = son[u] == -1 || sz[child] > sz[son[u]];
			if( heavier )
				son[u] = child;
		}
	}
	// Second pass: assign preorder indices with the heavy child first, so each
	// heavy chain gets consecutive dfn values; tp is the head of u's chain.
	void dfs2( int u, int tp ) {
		top[u] = tp;
		dfn[u] = ++dfnc;
		rdfn[dfn[u]] = u;
		const int heavy = son[u];
		if( heavy == -1 )
			return;	// leaf: no children to visit
		dfs2(heavy, tp);
		for( int e = rg.head[u]; e != -1; e = rg.nxt[e] )
			if( rg.to[e] != heavy )
				dfs2(rg.to[e], rg.to[e]);
	}
	// Build the decomposition; must run after `rg` has been filled.
	void init() {
		memset(son, -1, sizeof son);
		dfs1(0, -1, 1);
		dfs2(0, 0);
	}
	// Lowest common ancestor: repeatedly lift the vertex whose chain head is
	// deeper until both share a chain; the shallower of the two is the answer.
	int lca( int u, int v ) {
		for( ; top[u] != top[v]; u = pa[top[u]] )
			if( dep[top[u]] < dep[top[v]] )
				std::swap(u, v);
		return dep[v] < dep[u] ? v : u;
	}
}

// mark[u]: for the sam state u reached by reading the prefix str[1..i], that end position i; 0 otherwise.
// rmark[v]: for the rsam state v reached by reading str[n], str[n-1], ..., str[i], that start position i.
// rmark2nod[i]: that rsam state for start position i (filled for i in [3, n]).
int mark[N], rmark[N], rmark2nod[N];
// solve_ans: running total for the "star strictly inside" case.
// now: weight of the union of rsam root paths of the vertices currently in `st`.
ll solve_ans = 0, now = 0;
// HLD::dfn values of the rsam vertices currently in the maintained set, kept sorted.
// (The element type was missing, which made this line a compile error.)
std::set<int> st;

// Insert the rsam vertex paired with sam state u into `st` and update `now`.
// If u is the state of the prefix ending at pos = mark[u], a star at pos+1 can
// be followed by any prefix of str[pos+2..n]. In rsam those strings are exactly
// the root path of x = rmark2nod[pos+2], whose weight is HLD::val[x].
// With the chosen vertices sorted by dfn, their union of root paths weighs
//   sum val[x_i] - sum val[lca(x_i, x_{i+1})],
// so inserting x between its dfn-neighbours L and R only changes those terms.
void ins_node( int u ) {
	const int pos = mark[u];
	if( !pos ) return;	// not a prefix state: carries no end position of its own
	const int x = rmark2nod[pos + 2];
	const int key = HLD::dfn[x];
	const auto after = st.lower_bound(key);
	const bool has_next = after != st.end();
	// Check before decrementing: --st.begin() is undefined behaviour.
	const bool has_prev = after != st.begin();
	now += HLD::val[x];
	int L = -1, R = -1;
	if( has_prev ) {
		auto before = after;
		--before;
		L = HLD::rdfn[*before];
		now -= HLD::val[HLD::lca(L, x)];
	}
	if( has_next ) {
		R = HLD::rdfn[*after];
		now -= HLD::val[HLD::lca(R, x)];
	}
	if( has_prev && has_next )
		now += HLD::val[HLD::lca(L, R)];	// L and R were adjacent before x arrived
	st.insert(key);
}

// Insert every sam state of the subtree rooted at u (suffix-link tree `g`), in preorder.
void ins_tree( int u ) {
	ins_node(u);
	for( int e = g.head[u]; e != -1; e = g.nxt[e] ) {
		const int child = g.to[e];
		ins_tree(child);
	}
}

int sz[N], son[N];	// subtree size and heavy child (-1 for a leaf) in sam's suffix-link tree `g`

// Fill sz[] and son[] for the subtree of u; among equal-size children the
// first one visited stays heavy.
void init_sack( int u ) {
	son[u] = -1;
	sz[u] = 1;
	for( int e = g.head[u]; e != -1; e = g.nxt[e] ) {
		const int child = g.to[e];
		init_sack(child);
		sz[u] += sz[child];
		if( son[u] != -1 && sz[child] <= sz[son[u]] )
			continue;	// current heavy child is at least as large
		son[u] = child;
	}
}

void sack( int u, bool clr ) {
	// printf( "u = %d\n", u );
	for( int i = g.head[u]; ~i; i = g.nxt[i] )
		if( g.to[i] != son[u] )
			sack( g.to[i], true );
	if( son[u] != -1 )
		sack( son[u], false );
	for( int i = g.head[u]; ~i; i = g.nxt[i] )
		if( g.to[i] != son[u] )
			ins_tree( g.to[i] );
	ins_node(u);
	// printf( "u = %d, now = %lld\n", u, now );
	if( u )
		solve_ans += 1LL * now * (sam.len[u] - sam.len[sam.pa[u]]);
	if( clr ) st.clear(), now = 0;
}

// Count the distinct strings s1*s2 whose star is strictly inside (s1, s2 non-empty).
// Builds `sam` over str[1..n-2] and `rsam` over str[n], str[n-1], ..., str[3],
// their suffix-link trees `g` / `rg`, then runs DSU on tree over `g`.
// main() calls this only when n >= 3.
ll solve() {
	// Forward automaton for the left parts s1.
	sam.init();
	int cur = 0;
	for( int i = 1; i <= n - 2; ++i )
		cur = sam.append(cur, str[i] - 'a');
	g.init();
	for( int v = 1; v < sam.idx; ++v )
		g.link(sam.pa[v], v);
	// Replay each prefix str[1..i] to find its state and tag it with i.
	cur = 0;
	for( int i = 1; i <= n - 2; ++i ) {
		cur = sam.ch[cur][str[i] - 'a'];
		mark[cur] = i;
	}

	// Automaton of the reversed tail for the right parts s2.
	rsam.init();
	cur = 0;
	for( int i = n; i >= 3; --i )
		cur = rsam.append(cur, str[i] - 'a');
	rg.init();
	for( int v = 1; v < rsam.idx; ++v )
		rg.link(rsam.pa[v], v);
	// Replay str[n], ..., str[i] to find the state for start position i.
	cur = 0;
	for( int i = n; i >= 3; --i ) {
		cur = rsam.ch[cur][str[i] - 'a'];
		rmark[cur] = i;
		rmark2nod[i] = cur;
	}

	HLD::init();
	init_sack(0);
	sack(0, false);
	return solve_ans;
}

int main() {
	_w = scanf( "%s", str+1 );
	n = (int)strlen(str+1);
	ll ans = 0;
	ans += solve_origin();
	// printf( "after origin = %lld\n", ans );
	if( n >= 2 ) {
		ans += solve_before();
		ans += solve_after();
	}
	// printf( "before after = %lld\n", ans );
	if( n >= 3 ) {
		ans += solve();
	}
	printf( "%lld\n", ans+2 );
	return 0;
}