Understanding the Aho-Corasick automaton (AC automaton)

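I have been studying compiler principles recently. The book mentions this algorithm right after Section 3.3 (recognition of tokens), so I implemented it myself from material I found online. It only supports lowercase English letters.
It is probably less efficient than the implementation on OI-wiki.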

1. Background

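Roughly, the topics involved are:

  • Trie, a prefix tree (also called a dictionary tree); the illustration here is quite intuitive
  • BFS, breadth-first search over the trie
  • State compression, a small optimization added on a whim; with small data sets it may even make things slower
    The main idea is to use each bit of an int/long as a bool through bit operations, instead of declaring separate bool variables. The key operations are:
    • (n>>k)&1 reads the k-th bit of n
    • n^(1<<k) flips the k-th bit of n
    • n|(1<<k) sets the k-th bit of n to 1
  • The failure (mismatch) function of the AC automaton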
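Here is a small sketch of these bit operations. The helper names getBit, flipBit and setBit are hypothetical. The word is unsigned so that bit 31 behaves predictably. Flipping a bit twice clears it again, but setting it twice leaves it set. That makes | the safer way to mark something.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helpers: treat the 32 bits of one unsigned word as 32 bools.
bool getBit(std::uint32_t n, int k) { // (n >> k) & 1: read bit k
    assert(k >= 0 && k < 32);
    return (n >> k) & 1u;
}
std::uint32_t flipBit(std::uint32_t n, int k) { // n ^ (1 << k): flip bit k
    assert(k >= 0 && k < 32);
    return n ^ (1u << k);
}
std::uint32_t setBit(std::uint32_t n, int k) { // n | (1 << k): set bit k to 1
    assert(k >= 0 && k < 32);
    return n | (1u << k);
}
```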

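2. Idea

The main workflow has three steps:
1. Build a trie.
2. Use BFS to compute the failure target of every node. The result works much like a DFA (deterministic finite automaton).
3. Scan the text to find the matches.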
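The sketch below makes the DFA comparison concrete. It shows a common variant, the one OI-wiki uses, rather than the approach in this post.

After the BFS, every missing edge is filled with the corresponding edge of the failure target. The table then becomes a complete transition function, and the scan needs exactly one lookup per character.

Assumptions of this sketch:
- The struct name AcDfa and its members are hypothetical.
- Only a-z is handled.
- It only counts matches.
- All words must be added before build().

```cpp
#include <array>
#include <cassert>
#include <queue>
#include <string>
#include <vector>

// Hypothetical "DFA form" of the automaton: after build(), go[state][c] is defined for every
// state and letter, so the scan never loops over failure links. Lowercase a-z only.
struct AcDfa {
    std::vector<std::array<int, 26>> go{std::array<int, 26>{}}; // node 0 is the root
    std::vector<int> fail{0};
    std::vector<int> out{0}; // words ending at this node or at a suffix of it (after build)

    void add(const std::string &word) {
        int p = 0;
        for (char ch : word) {
            int c = ch - 'a';
            assert(c >= 0 && c < 26);
            if (go[p][c] == 0) { // create a new node
                go[p][c] = static_cast<int>(go.size());
                go.push_back(std::array<int, 26>{});
                fail.push_back(0);
                out.push_back(0);
            }
            p = go[p][c];
        }
        out[p]++;
    }

    void build() { // BFS starting from the root's children
        std::queue<int> q;
        for (int c = 0; c < 26; c++)
            if (go[0][c] != 0) q.push(go[0][c]); // depth-1 nodes fail to the root (fail is already 0)
        while (!q.empty()) {
            int u = q.front();
            q.pop();
            out[u] += out[fail[u]]; // fail[u] is shallower, so its count is already final
            for (int c = 0; c < 26; c++) {
                int v = go[u][c];
                if (v != 0) {
                    fail[v] = go[fail[u]][c]; // the failure node's row is already complete
                    q.push(v);
                } else {
                    go[u][c] = go[fail[u]][c]; // missing edge: borrow the failure node's edge
                }
            }
        }
    }

    long long countMatches(const std::string &text) const { // total occurrences of all words
        long long total = 0;
        int p = 0;
        for (char ch : text) {
            int c = ch - 'a';
            assert(c >= 0 && c < 26);
            p = go[p][c]; // exactly one table lookup per character
            total += out[p];
        }
        return total;
    }
};
```

For example, suppose you call add("he"), add("she"), add("his"), add("hers"), then build(). countMatches("ushers") then returns 3, counting she, he and hers.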

3. Code

3.1 Trie

A fairly simple approach is:

```cpp
struct Node {
    Node* next[26]; // child for each letter a-z
    bool end;       // whether a word ends at this node
};
```
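The array index stands for the edge, i.e. the character's ASCII offset from 'a' (char - 'a'). The stored value is the next node.
That layout scatters the nodes across memory, and the struct is not really needed anyway. So this post replaces it with a single 2D array nodes[x][y]:
- x is the node index.
- y is the edge for the letter y+'a', with 0 <= y <= 25. For example, the edge for a is 'a'-'a' = 0.
- The value is the index of the child node.

For a visual representation, see OI-wiki.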
```cpp
constexpr int s = 50;
int nodes[s][26]{0};
int now = 1;
bool end[s];
```
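  • s is the array size, i.e. the maximum number of nodes including the root. It is pulled out as a constant because the state compression later uses it too; it works like a #define
  • end marks whether the node with that index is the end of a word, which decides whether a match succeeded
  • now is the index the next new node will get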

Then write the add() (or insert) method:

```cpp
void trieAdd(const std::string &text) { // insert a word
    int p = 0; // node we are currently at; start from the root
    for (char each : text) { // for each character
        if (nodes[p][each - 'a'] == 0) // no child for this character yet: create one
            nodes[p][each - 'a'] = now++; // store the new child's index
        p = nodes[p][each - 'a']; // move down to the child for this character
    }
    end[p] = true; // a word ends here
}
```
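This self-contained sketch shows how the array encodes the words. It inserts words the same way, then looks one up by walking the same edges.

Two things here are hypothetical:
- The helper trieContains.
- The end flag is renamed isEnd. A global named end can clash with std::end once a standard header such as <string> is included under using namespace std.

```cpp
#include <cassert>
#include <string>

// Same layout as the text: node 0 is the root, nodes[x][c] is the child of x on
// letter 'a'+c (0 = no child), and now is the next free index.
constexpr int s = 50;
int nodes[s][26]{};
int now = 1;
bool isEnd[s]{}; // hypothetical rename of `end`

void trieAdd(const std::string &text) {
    int p = 0;
    for (char each : text) {
        int c = each - 'a';
        assert(c >= 0 && c < 26); // only lowercase letters are supported
        if (nodes[p][c] == 0) {
            assert(now < s); // out of node space
            nodes[p][c] = now++;
        }
        p = nodes[p][c];
    }
    isEnd[p] = true;
}

// Hypothetical helper: walk the same edges; the word is present only if the walk
// never falls off the trie and stops on an end node.
bool trieContains(const std::string &text) {
    int p = 0;
    for (char each : text) {
        int c = each - 'a';
        if (c < 0 || c >= 26 || nodes[p][c] == 0) return false;
        p = nodes[p][c];
    }
    return isEnd[p];
}
```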

3.1.1 Optimizing endNodes

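This is an optional optimization. The end marker changes from a bool array to an array of 32-bit integers, where each of the 32 bits records whether one node is an accepting node (0 or 1). An unsigned type (uint32_t) is used so that setting bit 31 is well defined.
To make the remainder cheap to compute, the bit width should be a power of two (8, 16, 32, etc.).

```cpp
constexpr int s = 50;
int nodes[s][26]{0};
int now = 1;
constexpr int bitW = 32; // must be a power of two
uint32_t endNodes[s / bitW + (s % bitW > 0 ? 1 : 0)]{0}; // bit p: whether node p is the end of a word (uint32_t is from <cstdint>)
```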

So the declarations become the code above, and the add method becomes:

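```cpp
void trieAdd(const std::string &text) { // insert a word
    int p = 0; // node we are currently at; start from the root
    for (char each : text) { // for each character
        if (nodes[p][each - 'a'] == 0) // no child for this character yet: create one
            nodes[p][each - 'a'] = now++; // store the new child's index
        p = nodes[p][each - 'a']; // move down to the child for this character
    }
    // Mark node p as the end of a word: set bit (p % bitW) of endNodes[p / bitW].
    // Because bitW is a power of two, p % bitW equals p & (bitW - 1), i.e. p & 31.
    // |= (set) is used rather than ^= (flip), so inserting the same word twice keeps it marked.
    endNodes[p / bitW] |= 1u << (p & (bitW - 1));
}
```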

Then, to check whether node p is an accepting (end) node:

```cpp
if ((endNodes[p / bitW] >> (p & (bitW - 1))) & 1) {
    // node p is an accepting (end) node
}
```
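Here is a self-contained sketch of the packed flags. It is checked against std::bitset, which does the same packing for you. The names markEnd and isEndNode are hypothetical. The expression (s + bitW - 1) / bitW is just another way to write the rounded-up word count used above.

```cpp
#include <bitset>
#include <cassert>
#include <cstdint>

constexpr int s = 50;
constexpr int bitW = 32; // must be a power of two
std::uint32_t endNodes[(s + bitW - 1) / bitW]{}; // ceil(s / bitW) words, all bits 0

void markEnd(int p) { // set bit p % bitW of word p / bitW
    assert(p >= 0 && p < s);
    endNodes[p / bitW] |= 1u << (p & (bitW - 1));
}
bool isEndNode(int p) { // read that bit back
    assert(p >= 0 && p < s);
    return (endNodes[p / bitW] >> (p & (bitW - 1))) & 1u;
}

std::bitset<s> endBits; // the standard-library equivalent, for comparison
```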

3.2 Failure function

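First prepare a fail array. The index is a node, and the value says which node to jump to when a match fails at that node.
Here the stored value is the target index plus one, so 0 means "not set yet". This avoids pre-filling every element with -1, which would otherwise be necessary because 0 (the root) is itself a valid target.
Then use a queue; its first-in-first-out order guarantees BFS order.
Then loop over every node and its children to find the failure targets.
The procedure in detail:
Start from the root -> visit the nodes in BFS (breadth-first) order -> for each node, go through each of its edges -> for each non-empty edge (the child it points to is != 0, since no edge can point back to the root), there are three cases for the child:

  • If the parent is 0, i.e. the root, the child's failure target is the root, 0.
  • Otherwise, if the parent's failure target has an edge for the same character, the child's failure target is the node that edge points to.
  • If neither applies, treat the parent's failure target as the new parent and repeat the check above until the parent is the root. If even the root has no such edge, the failure target is the root.

Once the failure target is found, push the child the current edge points to onto the queue and move on to the next edge.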

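```cpp
// Uses queue (<queue>), invalid_argument (<stdexcept>) and to_string (<string>) with using namespace std, as in the full code.
int fail[s]{0}; // fail[x] = (failure target of node x) + 1; 0 means "not set"
void trieFail() { // compute the failure targets
    int n; // index of the node being processed
    queue<int> ns; // indices of the nodes waiting to be processed
    ns.push(0); // start from the root
    while (!ns.empty()) { // stop when nothing is left to process
        n = ns.front(); // take the oldest pending node and remove it
        ns.pop();
        for (int i = 0; i < 26; i++) { // go through every child of this node
            if (nodes[n][i] != 0) {
                { // find the failure target
                    int pp = n; // parent: we iterate over n's children, so n is the parent
                    int nn = nodes[n][i]; // current node, i.e. the child being visited
                    if (pp == 0) { // parent is the root: the failure target is the root
                        fail[nn] = 1;
                    } else
                        while (true) { // follow the parent's failure target, then its failure target, ... until a suitable node or the root
                            if (fail[pp] != 0) pp = fail[pp] - 1; // treat the parent's failure target as the new parent
                            else
                                throw invalid_argument(to_string(nn) + "-" + to_string(pp) + "x"); // unreachable: BFS sets shallower nodes first
                            if (nodes[pp][i] != 0) { // this node (a failure target of the parent) has edge i: its child is the failure target
                                fail[nn] = nodes[pp][i] + 1;
                                break;
                            }
                            if (pp == 0) { // reached the root without finding edge i: the failure target is the root
                                fail[nn] = 1;
                                break;
                            }
                        }
                }
                ns.push(nodes[n][i]); // queue the child for processing
            }
        }
    }
}
```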
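This small self-contained sketch lets you check the procedure by hand with the classic word set he, she, his, hers. It differs from the code above in two ways:
- It stores fail[x] directly, with no +1 offset.
- It keeps a hypothetical label array holding each node's string, so the results are easy to read.

For instance, the failure target of she is he, and that of his is s.

```cpp
#include <cassert>
#include <queue>
#include <string>

constexpr int s = 50;
int nodes[s][26]{};
int fail[s]{};        // fail[x] = failure target of x, stored directly (no +1 offset)
std::string label[s]; // hypothetical: the string spelled from the root to node x
int now = 1;

int trieAdd(const std::string &text) { // inserts text and returns the node it ends at
    int p = 0;
    for (char each : text) {
        int c = each - 'a';
        assert(c >= 0 && c < 26);
        if (nodes[p][c] == 0) {
            assert(now < s);
            label[now] = label[p] + each;
            nodes[p][c] = now++;
        }
        p = nodes[p][c];
    }
    return p;
}

void trieFail() { // BFS, following the parent's failure chain as described above
    std::queue<int> q;
    q.push(0);
    while (!q.empty()) {
        int n = q.front();
        q.pop();
        for (int i = 0; i < 26; i++) {
            int child = nodes[n][i];
            if (child == 0) continue;
            if (n != 0) { // children of the root keep fail = 0 (the root)
                int f = fail[n];
                while (f != 0 && nodes[f][i] == 0) f = fail[f]; // walk up the failure chain
                fail[child] = nodes[f][i]; // 0, i.e. the root, if even the root lacks edge i
            }
            q.push(child);
        }
    }
}
```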

3.3 Searching

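Scan the text one character at a time:
- If the current node has an edge for the character, follow it.
- Otherwise, keep jumping to failure targets until a node with that edge is found or the root is reached.

After each character, walk the failure chain from the current node back to the root and report every accepting node on it. Each such node is a word that ends at this position.

Checking only the current node is not enough. For example, take the words "abcd" and "bc". Reading "abc" reaches the node for "abc", and "bc" is only found through that node's failure target.

To turn an accepting node back into the matched string, we need its depth, i.e. the length of the string from the root to it. A child is always created after its parent, so it has a larger index. One pass over the nodes in index order therefore fills in all the depths.

```cpp
vector<string> trieFind(const std::string &text) { // find every occurrence of every inserted word
    // depth[x] = length of the string spelled from the root to node x.
    // Children always have larger indices than their parents, so one pass in index order works.
    vector<int> depth(now, 0);
    for (int x = 0; x < now; x++)
        for (int c = 0; c < 26; c++)
            if (nodes[x][c] != 0) depth[nodes[x][c]] = depth[x] + 1;
    vector<string> rev; // all matched substrings
    int p = 0; // start from the root
    for (size_t i = 0; i < text.size(); i++) { // for each character
        int c = text[i] - 'a';
        if (c < 0 || c >= 26) { p = 0; continue; } // only a-z is supported: other characters reset to the root
        while (p != 0 && nodes[p][c] == 0) // no edge for c: jump to the failure target (fail is stored +1)
            p = fail[p] - 1;
        p = nodes[p][c]; // follow the edge; at the root a missing edge (0) just stays at the root
        for (int q = p; q != 0; q = fail[q] - 1) // every accepting node on the failure chain is a match ending at i
            if ((endNodes[q / bitW] >> (q & (bitW - 1))) & 1)
                rev.push_back(text.substr(i + 1 - depth[q], depth[q]));
    }
    return rev;
}
```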
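This self-contained sketch performs the same scan and also records where each match starts. Three things are assumptions of this sketch:
- The function name findAll.
- The (start, word) result type.
- Storing a word index on each node instead of computing depths.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <queue>
#include <string>
#include <utility>
#include <vector>

// Hypothetical self-contained version of the scan: returns (start index, word) for every
// occurrence of every word in text. Only lowercase a-z is supported in the words.
std::vector<std::pair<std::size_t, std::string>>
findAll(const std::vector<std::string> &words, const std::string &text) {
    std::vector<std::array<int, 26>> go(1, std::array<int, 26>{}); // trie edges; node 0 is the root
    std::vector<int> wordAt(1, -1); // index into words of the word ending at a node, or -1
    for (std::size_t w = 0; w < words.size(); w++) {
        int p = 0;
        for (char ch : words[w]) {
            int c = ch - 'a';
            assert(c >= 0 && c < 26);
            if (go[p][c] == 0) {
                go[p][c] = static_cast<int>(go.size());
                go.push_back(std::array<int, 26>{});
                wordAt.push_back(-1);
            }
            p = go[p][c];
        }
        wordAt[p] = static_cast<int>(w);
    }
    // Failure links by BFS, walking up the parent's failure chain (stored directly, no +1).
    std::vector<int> fail(go.size(), 0);
    std::queue<int> bfs;
    bfs.push(0);
    while (!bfs.empty()) {
        int n = bfs.front();
        bfs.pop();
        for (int c = 0; c < 26; c++) {
            int child = go[n][c];
            if (child == 0) continue;
            if (n != 0) {
                int f = fail[n];
                while (f != 0 && go[f][c] == 0) f = fail[f];
                fail[child] = go[f][c];
            }
            bfs.push(child);
        }
    }
    // Scan: fall back on mismatches, then report every word on the current node's failure chain.
    std::vector<std::pair<std::size_t, std::string>> found;
    int p = 0;
    for (std::size_t i = 0; i < text.size(); i++) {
        int c = text[i] - 'a';
        if (c < 0 || c >= 26) { p = 0; continue; } // other characters cannot be part of a match
        while (p != 0 && go[p][c] == 0) p = fail[p];
        p = go[p][c]; // stays at the root (0) if even the root has no edge c
        for (int q = p; q != 0; q = fail[q])
            if (wordAt[q] >= 0) {
                const std::string &w = words[wordAt[q]];
                found.emplace_back(i + 1 - w.size(), w);
            }
    }
    return found;
}
```

For example, findAll({"he", "she", "his", "hers"}, "ushers") yields (1, she), (2, he), (2, hers). Both she and he end at the same character, and both are found by walking the failure chain.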

3.4 Test code

```cpp
string printVector(const vector<string> &t) { // join the strings with spaces
    string re;
    for (auto &i : t)
        re += i + " ";
    return re;
}
int main() {
    // These four words need 56 nodes including the root, so s must be at least 56 for this test
    // (with s = 50 the nodes array would overflow).
    trieAdd("gsafsfawdasca");
    trieAdd("gdscascwadawxszg");
    trieAdd("ewyuoascasdwo");
    trieAdd("csaxawewddwaqw");
    trieFail();
    cout << to_string(now) << endl; // number of nodes used
    cout << printVector(trieFind("csaxawewddwaqwfsacasfegdscascwadawxszgagsafsfawdascawgrhherasdawdwzz")) << endl
         << printVector(trieFind("fsacasfegdscascwadawxszgsafsfawdascawgrhherasdawdwzzcsaxawewddwaqw")) << endl
         << printVector(trieFind("cegergecsaxaweqsdawryyrgte")) << endl;
    for (int i = 0; i < s; i++) // node indices
        cout << i << " ";
    cout << endl;
    for (auto a : fail) // failure target of each node (-1 for the root and unused slots)
        cout << a - 1 << " ";
    return 0;
}
```
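When changing the search code, it helps to compare against a brute-force reference. The sketch below uses a hypothetical function named naiveFindAll. It finds every occurrence of each word with std::string::find, overlapping ones included. It then sorts the result so it can be compared with a sorted copy of trieFind's output. It assumes the words are distinct and non-empty.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical brute-force reference: every occurrence of every word, overlapping ones included,
// sorted so the result does not depend on the order in which matches are reported.
std::vector<std::string> naiveFindAll(const std::vector<std::string> &words, const std::string &text) {
    std::vector<std::string> found;
    for (const std::string &w : words) {
        assert(!w.empty());
        // Restart one character after each hit so overlapping occurrences are counted too.
        for (std::size_t pos = text.find(w); pos != std::string::npos; pos = text.find(w, pos + 1))
            found.push_back(w);
    }
    std::sort(found.begin(), found.end());
    return found;
}
```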

3.5 Full code

```cpp
/*
* Copyright (c) Eritque arcus
* Date: 2022/1/11
* MIT license
*/
#include <cstdint>
#include <iostream>
#include <queue>
#include <stdexcept>
#include <string>
#include <vector>
using namespace std;
constexpr int s = 50; // maximum number of nodes
constexpr int bitW = 32; // must be a power of two
int nodes[s][26]{0}; // all nodes of the trie
int fail[s]{0}; // failure targets, stored +1 (0 = not set)
int now = 1; // index the next new node will get
uint32_t endNodes[s / bitW + (s % bitW > 0 ? 1 : 0)]{0}; // bit p: whether node p is the end of a word
void trieFail() { // compute the failure targets
    queue<int> ns; // indices of the nodes waiting to be processed
    ns.push(0); // start from the root
    while (!ns.empty()) { // stop when nothing is left to process
        int n = ns.front(); // take the oldest pending node (the parent)
        ns.pop();
        for (int i = 0; i < 26; i++) { // go through every child of this node
            if (nodes[n][i] == 0) continue;
            int pp = n; // parent of the child being visited
            int nn = nodes[n][i]; // the child being visited
            if (pp == 0) { // parent is the root: the failure target is the root
                fail[nn] = 1;
            } else
                while (true) { // follow the parent's failure targets until one has edge i or the root is reached
                    if (fail[pp] != 0) pp = fail[pp] - 1; // move to the failure target
                    else
                        throw invalid_argument(to_string(nn) + "-" + to_string(pp) + "x"); // unreachable
                    if (nodes[pp][i] != 0) { // this node has edge i: its child is the failure target
                        fail[nn] = nodes[pp][i] + 1;
                        break;
                    }
                    if (pp == 0) { // reached the root without edge i: the failure target is the root
                        fail[nn] = 1;
                        break;
                    }
                }
            ns.push(nn); // queue the child for processing
        }
    }
}
void trieAdd(const std::string &text) { // insert a word (lowercase a-z only)
    int p = 0; // node we are currently at; start from the root
    for (char each : text) { // for each character
        if (each < 'a' || each > 'z') throw invalid_argument("only a-z is supported");
        if (nodes[p][each - 'a'] == 0) { // no child for this character yet: create one
            if (now >= s) throw length_error("trie is full: increase s");
            nodes[p][each - 'a'] = now++; // store the new child's index
        }
        p = nodes[p][each - 'a']; // move down to the child
    }
    endNodes[p / bitW] |= 1u << (p & (bitW - 1)); // set bit p % bitW (p & (bitW - 1) == p % bitW since bitW is a power of two)
}
vector<string> trieFind(const std::string &text) { // find every occurrence of every inserted word
    // depth[x] = length of the string spelled from the root to node x.
    // Children always have larger indices than their parents, so one pass in index order works.
    vector<int> depth(now, 0);
    for (int x = 0; x < now; x++)
        for (int c = 0; c < 26; c++)
            if (nodes[x][c] != 0) depth[nodes[x][c]] = depth[x] + 1;
    vector<string> rev; // all matched substrings
    int p = 0; // start from the root
    for (size_t i = 0; i < text.size(); i++) { // for each character
        int c = text[i] - 'a';
        if (c < 0 || c >= 26) { p = 0; continue; } // only a-z is supported: other characters reset to the root
        while (p != 0 && nodes[p][c] == 0) // no edge for c: jump to the failure target (fail is stored +1)
            p = fail[p] - 1;
        p = nodes[p][c]; // follow the edge; at the root a missing edge (0) just stays at the root
        for (int q = p; q != 0; q = fail[q] - 1) // every accepting node on the failure chain is a match ending at i
            if ((endNodes[q / bitW] >> (q & (bitW - 1))) & 1)
                rev.push_back(text.substr(i + 1 - depth[q], depth[q]));
    }
    return rev;
}
string printVector(const vector<string> &t) { // join the strings with spaces
    string re;
    for (auto &i : t)
        re += i + " ";
    return re;
}
int main() {
    trieAdd("aaaaaa");
    trieAdd("bbbbbb");
    trieAdd("ababab");
    trieAdd("cacsafasfasfsgefacsdcewg");
    trieFail();
    cout << "Node count: " + to_string(now) << endl;
    cout << printVector(trieFind("aaaaaaabbbbbbbbbacacsafasfasfsgefacsdcewgbaabcaaaaaaabbbbbbbbbbbbabbbbbbbbbbbbaav")) << endl;
    for (int i = 0; i < s; i++) // node indices
        cout << i << " ";
    cout << endl;
    for (auto a : fail) // failure target of each node (-1 for the root and unused slots)
        cout << a - 1 << " ";
    return 0;
}
```

4. Visualization extension

The automaton can be drawn with the Python package visual-automata.

4.1 C++ changes

Optimization is not a concern in this part.
First declare a vector<int> final; to store the final (accepting) nodes.
Then push_back into it in the trieAdd method:

```cpp
void trieAdd(const std::string &text) { // insert a word
    int p = 0; // node we are currently at; start from the root
    for (char each : text) { // for each character
        if (nodes[p][each - 'a'] == 0) // no child for this character yet: create one
            nodes[p][each - 'a'] = now++; // store the new child's index
        p = nodes[p][each - 'a']; // move down to the child for this character
    }
    final.push_back(p); // remember the accepting node for the visualization
    endNodes[p / bitW] |= 1u << (p & (bitW - 1)); // mark node p as the end of a word (p & (bitW - 1) == p % bitW)
}
```

Then add the following at the end of main, after trieFail() has run:

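```cpp
// requires the nlohmann/json library: #include <nlohmann/json.hpp>
nlohmann::json j;
for (int i = 0; i < now; i++) // element 0: transition (goto) table of every used node
    for (int a = 0; a < 26; a++)
        j[0][i][a] = nodes[i][a];
j[1] = fail;  // element 1: failure targets (still stored +1)
j[2] = final; // element 2: accepting nodes
cout << endl << j.dump();
```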

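This prints the transition table in JSON format:
- element 0 is the goto table
- element 1 is the fail array (still stored +1)
- element 2 is the list of accepting nodes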
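If you would rather not depend on nlohmann/json, the same [goto table, fail array, accepting nodes] shape can be printed by hand. This sketch assumes the globals look like the ones in this post. The function dumpJson and the vector name finals (used instead of final) are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <sstream>
#include <string>
#include <vector>

// Globals shaped like the ones in this post (fail keeps the +1 encoding).
constexpr int s = 50;
int nodes[s][26]{};
int fail[s]{};
int now = 1;
std::vector<int> finals; // hypothetical name for the post's `final` vector

// Hypothetical dependency-free replacement for j.dump(): [[goto rows], [fail], [accepting nodes]]
std::string dumpJson() {
    assert(now <= s);
    std::ostringstream out;
    out << "[[";
    for (int i = 0; i < now; i++) { // element 0: one row of 26 targets per used node
        out << (i ? ",[" : "[");
        for (int a = 0; a < 26; a++) out << (a ? "," : "") << nodes[i][a];
        out << "]";
    }
    out << "],[";
    for (int i = 0; i < s; i++) out << (i ? "," : "") << fail[i]; // element 1: the whole fail array
    out << "],[";
    for (std::size_t i = 0; i < finals.size(); i++) out << (i ? "," : "") << finals[i]; // element 2
    out << "]]";
    return out.str();
}
```

In main, cout << endl << dumpJson(); should then print text in the same compact format as the jsonData string used by the Python script.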

4.2 python

First install the dependencies as described on PyPI. Then paste the JSON data from above into jsonData in the code below and run it.

```python
#!/usr/bin/python
# author: Eritque arcus
import json
from visual_automata.fa.nfa import VisualNFA

jsonData = "[[[1,7,18,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],[2,13,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],[3,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],[4,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],[5,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],[6,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],[0,8,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],[0,9,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],[0,10,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],[0,11,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],[0,12,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],[14,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],[0,15,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],[16,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],[0,17,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],[19,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],[0,0,20,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,21,0,0,0,0,0,0,0],[22,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],[0,0,0,0,0,23,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],[24,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,25,0,0,0,0,0,0,0],[0,0,0,0,0,26,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],[27,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,28,0,0,0,0,0,0,0],[0,0,0,0,0,29,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,30,0,0,0,0,0,0,0],[0,0,0,0,0,0,31,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],[0,0,0,0,32,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],[0,0,0,0,0,33,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],[34,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],[0,0,35,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,36,0,0,0,0,0,0,0],[0,0,0,37,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],[0,0,38,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],[0,0,0,0,39,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,40,0,0,0],[0,0,0,0,0,0,41,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0]],[0,1,2,3,4,5,6,1,8,9,10,11,12,8,2,14,15,16,1,2,19,1,2,1,2,1,1,2,1,1,1,1,1,1,2,19,1,1,19,1,1,1,0,0,0,0,0,0,0,0],[6,12,17,41]]"
a = json.loads(jsonData)
t = {}
for j in range(len(a[0])):
    it = {}  # inside transition
    for i in range(len(a[0][j])):
        if a[0][j][i] != 0:
            it[chr(i + ord('a'))] = {"q" + str(a[0][j][i])}
    if a[1][j] > 0:
        it["fail"] = {"q" + str(a[1][j] - 1)}
    t["q" + str(j)] = it
inputS = set([chr(b + ord('a')) for b in range(26)])
inputS.add("fail")
nfa = VisualNFA(
    states=set(["q" + str(b) for b in range(len(a[0]))]),
    input_symbols=inputS,
    transitions=t,
    initial_state="q0",
    final_states=set(["q" + str(b) for b in a[2]])
)
nfa.show_diagram(view=True)
```

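An NFA is used here instead of a DFA because a DFA diagram would have too many edges and become messy: every state can have up to 26 outgoing edges. Instead, each failure target is drawn as a separate edge labeled fail.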

4.3 Result

[Figure: NFA diagram generated by the script above]

end