学习笔记——AC自动机


AC自动机

题单

原理

AC自动机可以理解为在字典树连接失配 \(fail\) 指针后构成的字典图,\(fail\) 指针指向一个与当前状态公共后缀最长一个状态,用来解决多个模式串与文本串的匹配问题。

实现

建Trie树

不多加赘述,直接上代码。

inline void insert(char *s){
    int u=0;
    for(int i=1;s[i];i++){
        if(!tr[u][s[i]-'a']){
            tr[u][s[i]-'a']=++tot;
        }
        u=tr[u][s[i]-'a'];
    }
    mark[u]++;//标记为一个模式串
}

构建 \(fail\) 指针

\(bfs\) 来构建 \(fail\) 指针,首先将根节点下的节点放入队列中。

可以发现对于每个节点 \(u\) 来说,它的失配指针以及它失配指针的失配指针……,这些状态的后缀都是相同的,只不过匹配的长度越来越小,于是当存在 \(tr(u,i)\) 时,只需要将 \(fail(tr(u,i))\) 指向 \(u\) 的失配指针下面的状态,也就是 \(tr(fail(u),i)\),若 \(tr((fail(u)),i)\) 不存在,我们就继续找 \(fail(u)\) 的失配指针,直到根节点为止。

queue q;
inline void build(){
    for(int i=0;i<26;i++){
        if(tr[0][i]) q.push(tr[0][i]);
    }
    while(!q.empty()){
        int u=q.front();
        q.pop();
        for(int i=0;i<26;i++){
            if(tr[u][i]){
                fail[tr[u][i]]=tr[fail[u]][i];
                q.push(tr[u][i]);
            }
            else tr[u][i]=tr[fail[u]][i];
        }
    }
}

但是实现起来却与刚刚说的有所不同,这里如果不存在 \(tr(u,i)\) 我们直接令其为 \(tr(fail(u),i)\),因为发现访问到 \(tr(u,i)\) 且不存在时,我们便会到 \(fail(u)\) 处访问 \(tr(fail(u),i)\)。而这个 \(tr(fail(u),i)\) 若存在就是其失配指针,若不存在则已经在之前的 \(bfs\) 中连上 \(tr(fail(fail(u)),i)\) 或根节点了。这样会节省访问的时间。

查找匹配的模式串个数

枚举文本串的每一位,跳失配指针直到指向根节点(这个过程中所有以访问节点 \(u\) 为结尾的状态都可能单独作为一个模式串,换言之也就访问了所有以当前文本串字符结尾的状态),访问后记录答案并打上标记。

inline int query(char *s){
    int u=0,res=0;
    for(int i=1;s[i];i++){
        u=tr[u][s[i]-'a'];
        for(int j=u;j&&mark[j]!=-1;j=fail[j]){
            res+=mark[j],mark[j]=-1;
        }
    }
    return res;
}

例题

1.P3808 【模板】AC 自动机(简单版)

模板。

点击查看代码
struct AC_Automaton{
	int tr[maxn][26],tot,mark[maxn],fail[maxn];
	inline void insert(char *s){
		int u=0;
		for(int i=1;s[i];i++){
			if(!tr[u][s[i]-'a']){
				tr[u][s[i]-'a']=++tot;
			}
			u=tr[u][s[i]-'a'];
		}
		mark[u]++;
	}
	queue q;
	inline void build(){
		for(int i=0;i<26;i++){
			if(tr[0][i]) q.push(tr[0][i]);
		}
		while(!q.empty()){
			int u=q.front();
			q.pop();
			for(int i=0;i<26;i++){
				if(tr[u][i]){
					fail[tr[u][i]]=tr[fail[u]][i];
					q.push(tr[u][i]);
				}
				else tr[u][i]=tr[fail[u]][i];
			}
		}
	}
	inline int query(char *s){
		int u=0,res=0;
		for(int i=1;s[i];i++){
			u=tr[u][s[i]-'a'];
			for(int j=u;j&&mark[j]!=-1;j=fail[j]){
				res+=mark[j],mark[j]=-1;
			}
		}
		return res;
	}
}ac;
int n;
char s[maxn];
int main(){
	n=read();
	for(int i=1;i<=n;i++){
		scanf("%s",s+1);
		ac.insert(s);
	}
	ac.build();
	scanf("%s",s+1);
	printf("%d\n",ac.query(s));
	return 0;
}

2.P3796 【模板】AC 自动机(加强版)

改变一下标记的设置,查询完扫一遍即可。

点击查看代码
struct AC_Automaton{
	int tr[maxn][26],tot,fail[maxn],mark[maxn];
	queue q;
	int cnt[205];
	inline void clear(){
		while(!q.empty()) q.pop();
		tot=0;
		memset(tr,0,sizeof(tr));
		memset(fail,0,sizeof(fail));
		memset(mark,0,sizeof(mark));
		memset(cnt,0,sizeof(cnt));
	}
	inline void insert(char *s,int k){
		int u=0;
		for(int i=1;s[i];i++){
			if(!tr[u][s[i]-'a']) tr[u][s[i]-'a']=++tot;
			u=tr[u][s[i]-'a'];
		}
		mark[u]=k;
	}
	inline void build(){
		for(int i=0;i<26;i++){
			if(tr[0][i]) q.push(tr[0][i]);
		}
		while(!q.empty()){
			int u=q.front();
			q.pop();
			for(int i=0;i<26;i++){
				if(tr[u][i]){
					fail[tr[u][i]]=tr[fail[u]][i];
					q.push(tr[u][i]);
				}
				else tr[u][i]=tr[fail[u]][i];
			}
		}
	}
	inline void query(char *s){
		int u=0;
		for(int i=1;s[i];i++){
			u=tr[u][s[i]-'a'];
			for(int j=u;j;j=fail[j]){
				if(mark[j]) cnt[mark[j]]++;
			}
		}
	}
}ac;
int n;
char s[205][105],t[maxn];
int main(){
	while(1){
		ac.clear();
		n=read();
		if(!n) break;
		for(int i=1;i<=n;i++){
			scanf("%s",s[i]+1);
			ac.insert(s[i],i);
		}
		ac.build();
		scanf("%s",t+1);
		ac.query(t);
		int maxx=-1;
		for(int i=1;i<=n;i++){
			maxx=max(maxx,ac.cnt[i]);
		}
		printf("%d\n",maxx);
		for(int i=1;i<=n;i++){
			if(ac.cnt[i]==maxx){
				printf("%s\n",s[i]+1);
			}
		}
	}
	return 0;
}

3.P5357 【模板】AC 自动机(二次加强版)

主要区别在于不保证模式串两两不同,也就是在字典树中打上的 \(mark\) 标记可能有多个,常规做法是对每个模式串记录 \(mark\) 标记的地址然后最后输出答案,然而会超时。

我们观察本题与简单版的题目实现过程,发现对于每个模式串来说,简单版只求是否包含,也就是只需要访问一次,而二次加强版要求的是包含的次数,也就是要访问多次,并且每次访问我们都要一直跳 \(fail\) 指针直到根节点,这就会造成超时。我们发现,对于一个节点状态 \(u\),其失配指针指向的节点状态 \(f\),以及 \(f\) 失配指针指向的节点状态 \(f'\),存在形如 \(u\rightarrow f\rightarrow f'\) 的关系,不仅仅是遍历顺序,更新答案也是如此。并且发现:就深度而言 \(u>f>f'\),于是只要 \(u\) 处有答案的更新,\(f\) 处一定有,\(f'\) 也一定有(前提是该状态是一个完整的模式串),那么如果只在初始访问的 \(u\) 处打上标记,最后再计算贡献,就能节省时间。

考虑拓扑排序,每次用当前队首 \(u\) 的标记答案去更新其失配指针 \(f\) 的答案,最后每个模式串找到其对应的 \(mark\) 标记地址即可。

点击查看代码
struct AC_Automaton{
	int tr[maxn][26],tot,fail[maxn],mark[maxn];
	int id[maxn],deg[maxn],cnt[maxn],ans[maxn];
	inline void insert(char *s,int k){
		int u=0;
		for(int i=1;s[i];i++){
			if(!tr[u][s[i]-'a']) tr[u][s[i]-'a']=++tot;
			u=tr[u][s[i]-'a'];
		}
		mark[u]++;
		id[k]=u;
	}
	inline void build(){
		queue q;
		for(int i=0;i<26;i++){
			if(tr[0][i]) q.push(tr[0][i]);
		}
		while(!q.empty()){
			int u=q.front();
			q.pop();
			for(int i=0;i<26;i++){
				if(tr[u][i]){
					fail[tr[u][i]]=tr[fail[u]][i];
					deg[fail[tr[u][i]]]++;
					q.push(tr[u][i]);
				}
				else tr[u][i]=tr[fail[u]][i];
			}
		}
	}
	inline void query(char *s){
		int u=0;
		for(int i=1;s[i];i++){
			u=tr[u][s[i]-'a'];
			cnt[u]++;
		}
	}
	inline void topu(){
		queue q;
		for(int i=1;i<=tot;i++){
			if(!deg[i]) q.push(i);
		}
		while(!q.empty()){
			int u=q.front();
			q.pop();
			if(mark[u]) ans[u]=cnt[u];
			cnt[fail[u]]+=cnt[u];
			if(!(--deg[fail[u]])) q.push(fail[u]);
		}
	}
}ac;
int n;
char s[maxm];
int main(){
	n=read();
	for(int i=1;i<=n;i++){
		scanf("%s",s+1);
		ac.insert(s,i);
	}
	ac.build();
	scanf("%s",s+1);
	ac.query(s);
	ac.topu();
	for(int i=1;i<=n;i++){
		printf("%d\n",ac.ans[ac.id[i]]);
	}
	return 0;
}

4.P5231 JSOI2012 玄武密码

按照常规方法建AC自动机,只不过不记录单词末尾,匹配文本串时每个访问的节点都打上标记,说明该状态是文本串的一个子串。最后在字典树上遍历每个模式串,记录标记位置的最大值即为答案。

点击查看代码
struct AC_Automaton{
	int tr[maxm][4],tot,fail[maxm],mark[maxm];
	inline int get(char c){
		if(c=='E') return 0;
		if(c=='S') return 1;
		if(c=='W') return 2;
		if(c=='N') return 3;
	}
	inline void insert(char *s){
		int u=0;
		for(int i=1;s[i];++i){
			if(!tr[u][get(s[i])]) tr[u][get(s[i])]=++tot;
			u=tr[u][get(s[i])];
		}
	}
	queue q;
	inline void build(){
		for(int i=0;i<4;++i){
			if(tr[0][i]) q.push(tr[0][i]);
		}
		while(!q.empty()){
			int u=q.front();
			q.pop();
			for(int i=0;i<4;++i){
				if(tr[u][i]){
					fail[tr[u][i]]=tr[fail[u]][i];
					q.push(tr[u][i]);
				}
				else tr[u][i]=tr[fail[u]][i];
			}
		}
	}
	inline void query(char *s){
		int u=0;
		for(int i=1;s[i];++i){
			u=tr[u][get(s[i])];
			for(int j=u;j&&!mark[j];j=fail[j]){
				mark[j]=1;
			}
		}
	}
	inline int get_ans(char *s){
		int u=0,res=0;
		for(int i=1;s[i];++i){
			u=tr[u][get(s[i])];
			if(mark[u]) res=i;
		}
		return res;
	}
}ac;
int n,m;
char s[maxn][105],t[maxm];
int main(){
	n=read(),m=read();
	scanf("%s",t+1);
	for(int i=1;i<=m;i++){
		scanf("%s",s[i]+1);
		ac.insert(s[i]);
	}
	ac.build();
	ac.query(t);
	for(int i=1;i<=m;i++){
		printf("%d\n",ac.get_ans(s[i]));
	}
	return 0;
}

5.P3966 TJOI2013 单词

题面多少有一点点迷,大致意思是给你 \(n\) 个模式串,文本串是由这些模式串相连但不想通组成的,求每个模式串的匹配个数,例如模式串集合 \(S=\text{\{a,aa,aaa\}}\),那么文本串就可以是 \(\text{a*aa*aaa}\),中间 \(\text{*}\) 存在的目的就是把两个模式串隔开。

那么题目由多文本串匹配多模式串就变成了单文本串匹配多模式串,直接打上去,最后一个数据点会超时。

发现每次只有失配指针指向一个被标记过的节点(即某个模式串的结尾),答案才会更新,所以我们建一个指针指向在失配指针形成的链中被标记的节点,这样更新答案就可以去除多余无用的状态了。

点击查看代码
int n,len;
char t[maxn+205];
struct AC_Automaton{
	int tr[maxn][26],tot,fail[maxn],mark[maxn],nxt[maxn];
	int id[maxn],cnt[maxn];
	inline void insert(char *s,int k){
		for(int i=1;s[i];++i){
			t[++len]=s[i];
		}
		t[++len]='*';
		int u=0;
		for(int i=1;s[i];++i){
			if(!tr[u][s[i]-'a']) tr[u][s[i]-'a']=++tot;
			u=tr[u][s[i]-'a'];
		}
        if(!mark[u]) mark[u]=k;
		id[k]=mark[u];
	}
	queue q;
	inline void build(){
		for(int i=0;i<26;++i){
			if(tr[0][i]) q.push(tr[0][i]);
		}
		while(!q.empty()){
			int u=q.front();
			q.pop();
			for(int i=0;i<26;++i){
				if(tr[u][i]){
					fail[tr[u][i]]=tr[fail[u]][i];
					if(mark[fail[tr[u][i]]]) nxt[tr[u][i]]=fail[tr[u][i]];
					else nxt[tr[u][i]]=nxt[fail[tr[u][i]]];
					q.push(tr[u][i]);
				}
				else tr[u][i]=tr[fail[u]][i];
			}
		}
	}
	inline void query(){
		int u=0;
		for(int i=1;i<=len;++i){
			if(t[i]=='*'){
				u=0;
				continue;
			}
			u=tr[u][t[i]-'a'];
            if(mark[u]) cnt[mark[u]]++;
			for(int j=nxt[u];j;j=nxt[j]){
				cnt[mark[j]]++;
			}
		}
	}
}ac;
char s[maxn];
int main(){
	n=read();
	for(int i=1;i<=n;++i){
		scanf("%s",s+1);
		ac.insert(s,i);
	}	
	ac.build();
	ac.query();
	for(int i=1;i<=n;++i){
		printf("%d\n",ac.cnt[ac.id[i]]);
	}
	return 0;
}

6.P2444 POI2000 病毒

考虑一个合法的代码串在AC自动机上是如何表现的,其会通过不断跳 \(fail\) 指针来形成一个环,且不经过任何模式串的末尾标记。

于是我们建好AC自动机后,把每个模式串的连边以及失配指针都看做有向边,接着判环即可,一个优化是,如果 \(u\) 的失配指针指向的状态 \(f\) 是一个模式串(病毒),那么 \(u\) 也是不能访问的,最后 \(dfs\) 去判环时,要注意不能剪去有向边指向 \(0\) 的情况,因为存在一种情况是与字典树根形成了环,例如模式串为 \(\{01,11\}\) 时,就是 \(0\) 与字典树根形成了环,进一步理解可以表现为,当一个节点的某个子节点失配,其指针指向另一非病毒状态(也就是跳出了走向结尾的命运),并能按照此路径会到最初的节点,就是一个合法的代码串。同时,在标记时可以分“未访问、已访问且在当前路径中、已访问但不在当前路径中”三种状态,可以去掉重复的遍历。

点击查看代码
int n;
char s[maxn];
struct AC_Automaton{
	int tr[maxn][2],tot,fail[maxn];
	int mark[maxn],vis[maxn];
	queue q;
	inline void insert(){
		int u=0;
		for(register int i=1;s[i];++i){
			if(!tr[u][s[i]-'0']) tr[u][s[i]-'0']=++tot;
			u=tr[u][s[i]-'0'];
		}
		mark[u]=true;
	}
	inline void build(){
		if(tr[0][0]) q.push(tr[0][0]);
		if(tr[0][1]) q.push(tr[0][1]);
		while(!q.empty()){
			int u=q.front();
			q.pop();
			for(register int i=0;i<2;i++){
				if(tr[u][i]){
					fail[tr[u][i]]=tr[fail[u]][i];
					q.push(tr[u][i]);
					if(mark[fail[tr[u][i]]]) mark[tr[u][i]]=true;
				}
				else tr[u][i]=tr[fail[u]][i];
			}
		}
	}
	inline void dfs(int u){
		vis[u]=1;
		for(int i=0;i<2;i++){
			if(mark[tr[u][i]]) continue;
			if(vis[tr[u][i]]==1){
				printf("TAK\n");
				exit(0);
			}
			if(vis[tr[u][i]]==0) dfs(tr[u][i]);
		}
		vis[u]=-1;
	}
}ac;
int main(){
	n=read();
	for(register int i=1;i<=n;++i){
		scanf("%s",s+1);
		ac.insert();
	}
	ac.build();
	ac.dfs(0);
	printf("NIE\n");
	return 0;
}

AC自动机+dp

例题

1.P2322 HNOI2006 最短母串问题

\(n\) 很小,考虑状压,我们把每个节点以及此时的状态继续下来做 \(bfs\),然后找到一个状态全集的节点时,向前递归去输出该字符串。

点击查看代码
int n;
char s[605];
struct AC_Automaton{
	int tr[605][26],tot,fail[605],mark[605];
	inline void insert(int k){
		int u=0,l=strlen(s+1);
		for(int i=1;i<=l;++i){
			if(!tr[u][s[i]-'A']) tr[u][s[i]-'A']=++tot;
			u=tr[u][s[i]-'A'];
		}
		mark[u]|=(1<<(k-1));
	}
	inline void build(){
		queue q;
		for(int i=0;i<26;++i){
			if(tr[0][i]) q.push(tr[0][i]);
		}
		while(!q.empty()){
			int u=q.front();
			q.pop();
			for(int i=0;i<26;++i){
				if(tr[u][i]){
					fail[tr[u][i]]=tr[fail[u]][i];
					q.push(tr[u][i]);
					mark[tr[u][i]]|=mark[tr[fail[u]][i]];
				}
				else tr[u][i]=tr[fail[u]][i];
			}
		}
	}
}ac;
int pre[maxm],ans[maxm],cnt,pos;
bool vis[605][1<<13];
int main(){
	n=read();
	for(int i=1;i<=n;i++){
		scanf("%s",s+1);
		ac.insert(i);
	}
	ac.build();
	queue q;
	q.push(make_pair(0,0));
	vis[0][0]=1;
	while(!q.empty()){
		int u=q.front().first,sit=q.front().second;
		q.pop();
		if(sit==(1< stac;
			while(pos){
				stac.push(ans[pos]+'A');
				pos=pre[pos]; 
			}
			while(!stac.empty()){
				printf("%c",stac.top());
				stac.pop();
			}
			return printf("\n"),0;
		}
		for(int i=0;i<26;i++){
			int v=ac.tr[u][i];
			if(!vis[v][sit|ac.mark[v]]){
				vis[v][sit|ac.mark[v]]=1;
                                //这里的cnt与pos本质相同,都是bfs序,只不过cnt是预处理出了子节点,而pos表示当前的节点。
                                //或者理解为cnt是入队的顺序,而pos是出队的顺序,二者是等价的。
				pre[++cnt]=pos,ans[cnt]=i;
				q.push(make_pair(v,sit|ac.mark[v]));
			}
		}
		++pos;
	}
	return 0;
}

2.P4052 JSOI2007 文本生成器

AC自动机上dp的常规做法,令 \(dp(i,j)\) 表示长度为 \(i\),末位字符在字典树上的标号为 \(j\) 的方案数。考虑减法原理,用长度为 \(m\) 的字符串总方案数,即 \(26^m\),减去不包含任何一个模式串的情况数,即为答案。

\[\text{至少包含一个的方案数}=\text{总方案数}-\text{一个也不包含的方案数} \]

同上面几题,一个状态的失配指针若是模式串标记处,那么这个状态也是模式串标记处,于是可以在构建失配指针时把是否包含模式串预处理出,接着只要 \(tr(j,k)\) 不是模式串的结尾,那么 \(dp(i+1,tr(j,k))\) 就可以由 \(dp(i,j)\) 转移,刷表法即可。

点击查看代码
int n,m;
char s[105];
struct AC_Automaton{
	int tr[maxn][26],tot,fail[maxn],mark[maxn];
	inline void insert(){
		int u=0;
		for(int i=1;s[i];++i){
			if(!tr[u][s[i]-'A']) tr[u][s[i]-'A']=++tot;
			u=tr[u][s[i]-'A'];
		}
		mark[u]=1;
	}
	inline void build(){
		queue q;
		for(int i=0;i<26;++i){
			if(tr[0][i]) q.push(tr[0][i]);
		}
		while(!q.empty()){
			int u=q.front();
			q.pop();
			for(int i=0;i<26;++i){
				if(tr[u][i]){
					fail[tr[u][i]]=tr[fail[u]][i];
					q.push(tr[u][i]);
					mark[tr[u][i]]|=mark[tr[fail[u]][i]];
				}
				else tr[u][i]=tr[fail[u]][i];
			}
		}
	}
}ac;
inline int q_pow(int x,int p){
	int ans=1;
	while(p){
		if(p&1) ans=ans*x%mod;
		x=x*x%mod;
		p>>=1;
	}
	return ans;
}
int dp[105][maxn];
int main(){
	n=read(),m=read();
	for(int i=1;i<=n;i++){
		scanf("%s",s+1);
		ac.insert();
	}
	ac.build();
	dp[0][0]=1;
	for(int i=0;i

3.P4045 JSOI2009 密码