树链剖分&动态树学习笔记
如何将树上的一段路径转化为区间问题?我们可能会想到树上莫队中利用欧拉序性质的做法,但其不具有普适性,对于很多区间问题,难以将出现两次的元素减掉。而树链剖分与动态树都可以很好地解决这类问题。
树链剖分
树链剖分也称为重链剖分,适用于形态结构不发生变化的树(即静态)。
将树上所有边分为重边和轻边,每个节点与其子节点的所有连边中,有且仅有一条重边(叶子结点除外),其余为轻边。重边的求法:对于每个节点,考察其所有子节点,子树中包含节点最多的子节点称为“最重”。重边连接此节点与其“最重”的子节点。若有多个子节点同为“最重”,任选其一。
连在一起的重边从上而下形成重链。所有节点都包含在重链中,若一个节点上下均无重边,则其独自构成一个重链。此时有性质:从根节点出发的任意一条路径“穿过”的重链数不超过 \(\log{N}\)。
证明:每离开一条重链,走进另一条重链,必然使得其子树大小变为原来的 \(\frac{1}{k}\),其中 \(k\) 表示当前节点的子节点个数,且必然 \(≥2\)。于是得证。
对此树求DFS序(注意不是欧拉序,每个节点仅出现一次)。DFS优先拓展重边。此时有性质:一条重链上的所有节点在DFS序中连在一起。
于是,树上的任意一条路径,一定可以被拆分为至多 \(\log{N}\) 个连续区间。我们对每个区间分别求解,最后加在一起,即为答案。
对于DFS序中的元素,常用线段树/树状数组/平衡树等 \(\log{N}\) 级别的数据结构维护。故树链剖分的时间复杂度为 \(O(N\log^2{N})\)。但由于路径穿过的重链往往远不及 \(\log{N}\) 个,且每个重链对应的区间长度也较小,故此算法常数很小,实际效率很高。
模板题
Code
#include
#include
#define ll long long
using namespace std;
const int N = 1e5 + 5;
int n, m, a[N];
int head[N], nxt[N << 1], ver[N << 1], tot;
int id[N], sz[N], idx, dfn[N], d[N], son[N], fa[N], top[N];
struct Tree {
int l, r;
ll sum, add;
} t[N << 2];
void Add(int x, int y) {
nxt[++tot] = head[x]; head[x] = tot; ver[tot] = y;
}
void dfs1(int x) {
sz[x] = 1;
for (int i = head[x]; i; i = nxt[i]) {
int y = ver[i];
if (d[y]) continue;
d[y] = d[x] + 1; fa[y] = x;
dfs1(y);
sz[x] += sz[y];
if (sz[y] > sz[son[x]]) son[x] = y;
}
}
void dfs2(int x) {
dfn[++idx] = x; id[x] = idx;
if (!son[x]) return ;
top[son[x]] = top[x]; dfs2(son[x]);
for (int i = head[x]; i; i = nxt[i]) {
int y = ver[i];
if (d[y] < d[x] || y == son[x]) continue;
top[y] = y; dfs2(y);
}
}
void pushup(int p) {
t[p].sum = t[p << 1].sum + t[p << 1 | 1].sum;
}
void pushdown(int p) {
if (!t[p].add) return ;
Tree &u = t[p], &l = t[p << 1], &r = t[p << 1 | 1];
l.add += u.add; l.sum += u.add * (l.r - l.l + 1);
r.add += u.add; r.sum += u.add * (r.r - r.l + 1);
u.add = 0;
}
void Build(int l, int r, int p) {
t[p].l = l; t[p].r = r;
if (l == r) {
t[p].sum = a[dfn[l]]; return ;
}
int mid = l + r >> 1;
Build(l, mid, p << 1); Build(mid + 1, r, p << 1 | 1);
pushup(p);
}
void Insert(int l, int r, int v, int p) {
if (l <= t[p].l && t[p].r <= r) {
t[p].add += v; t[p].sum += (ll)v * (t[p].r - t[p].l + 1); return ;
}
int mid = t[p].l + t[p].r >> 1;
pushdown(p);
if (l <= mid) Insert(l, r, v, p << 1);
if (r > mid) Insert(l, r, v, p << 1 | 1);
pushup(p);
}
ll Query(int l, int r, int p) {
if (l <= t[p].l && t[p].r <= r) return t[p].sum;
int mid = t[p].l + t[p].r >> 1;
ll res = 0;
pushdown(p);
if (l <= mid) res += Query(l, r, p << 1);
if (r > mid) res += Query(l, r, p << 1 | 1);
pushup(p);
return res;
}
int main() {
scanf("%d", &n);
for (int i = 1; i <= n; i++) scanf("%d", &a[i]);
for (int i = 1; i < n; i++) {
int x, y;
scanf("%d%d", &x, &y);
Add(x, y); Add(y, x);
}
d[1] = 1; dfs1(1); top[1] = 1; dfs2(1);
Build(1, n, 1);
scanf("%d", &m);
while (m--) {
int opt, x, y, z;
scanf("%d%d", &opt, &x);
if (opt == 1) {
scanf("%d%d", &y, &z);
while (top[x] != top[y]) { //树剖部分
if (d[top[x]] < d[top[y]]) swap(x, y);
Insert(id[top[x]], id[x], z, 1);
x = fa[top[x]];
}
if (d[x] < d[y]) swap(x, y);
Insert(id[y], id[x], z, 1); //剩下一段单独处理
}
else if (opt == 2) {
scanf("%d", &z);
Insert(id[x], id[x] + sz[x] - 1, z, 1);
}
else if (opt == 3) {
scanf("%d", &y);
ll res = 0;
while (top[x] != top[y]) {
if (d[top[x]] < d[top[y]]) swap(x, y);
res += Query(id[top[x]], id[x], 1);
x = fa[top[x]];
}
if (d[x] < d[y]) swap(x, y);
res += Query(id[y], id[x], 1);
printf("%lld\n", res);
}
else printf("%lld\n", Query(id[x], id[x] + sz[x] - 1, 1));
}
return 0;
}
对于静态的树来说,树剖还是很不错的算法。
动态树
这就是大名鼎鼎的LCT(Link-Cut Tree)。支持动态地加边、删边,并兼容树链剖分的几乎所有操作。时间复杂度仅为 \(O(N\log{N})\)。
实际上LCT维护的是森林,以下以一棵树进行讨论。
与树链剖分类似地,将所有边分别实边和虚边,每个节点可以向下连最多一条实边,但也可以不连。实边构成实边链,每个节点包含在一条链中。
接下来就是LCT的关键。对每条实边链,用一个Splay维护,Splay的中序遍历对应链中从上往下的顺序。实边链中最上面的节点不一定是Splay的根节点。
然后,将这些Splay连接起来,构成一棵树,以下称为Splay的树。对于每条实边链,连接其最上面节点(设为 \(x\) )与其父节点(设为 \(y\) )的虚边在Splay的树中连接这条实边链对应Splay的根节点与 \(y\)(连接方法之后讨论)。而如何通过Splay的树对应到原树?实边显然可以通过每个Splay的中序遍历得到,虚边则通过每个Splay中最“左”的点与其根节点向上连到的点得到。
接着讨论如何连接Splay的树。在以下的操作中,要求每个Splay要隔离,但又要求向上追溯到原树的根(为什么?1.LCT维护的实际是森林;2.LCT是无根树,之后的操作会进行“换根”)。我们的处理方法是:每个Splay的根节点向上的连边仅有“向上连”而没有“向下连”,即“认父不认子”。Splay和rotate中也要特殊判断,防止转错。
以下介绍LCT的几大操作。
access(int x)
作用是,在原树中构造出从根到 \(x\) 的实边路径,且不再包含 \(x\) 的子节点。并将 \(x\) splay到其Splay的根。
将 \(x\) splay到根节点(仅为其splay根节点),找到其父节点 \(y\),将 \(y\) splay到根节点。则此时 \(y\) 一定没有后继。将 \(x\) 为根的子树全部接到 \(y\) 的右儿子。不断循环该操作即可。
makeroot(int x)
将 \(x\) 变成原树中的根节点。
只需access(x),此时 \(x\) 到根的路径构成一棵splay,将其翻转即可。
findroot(int x)
找到 \(x\) 在原树中的根节点。附属作用,将根节点转到splay中的根。
只需access(x),再一路向左即可。最后记得splay。
split(int x, int y)
若 \(x,y\) 在原树中连通,则建立一条 \(x\) 到 \(y\) 的实边路径。
makeroot(x),access(y)。
link(int x, int y)
若 \(x,y\) 不连通,加入边 \((x,y)\)。
先makeroot(x)。若findroot(y)!=x,则连接。因为不确定此时 \(x\) 是否有向下的实边,故只能连虚边。因为已经makeroot(x)了,故可以直接将 \(x\) 的父亲设为 \(y\)。
cut(int x, int y)
若 \(x,y\) 之间有边,删除此边。
先makeroot(x)。若 \(x,y\) 有边,则此时 findroot(y)==x。且因为前面的findroot(y),\(y\) 与 \(x\) 一定在同一条实边链上,则 \(y\) 为 \(x\) 的后继。且根据rotate的操作,\(y\) 一定为 \(x\) 的右子节点。于是:
void cut(int x, int y) {
makeroot(x);
if (findroot(y) == x && t[x].s[1] == y && !t[y].s[0]) {
t[x].s[1] = t[y].p = 0; pushup(x);
}
}
至此为LCT的所有操作。第一次写代码时可能很难,但背模板还是可以的。LCT与网络流有些相似,不会在代码的实现上有很多变通,基本只有pushup和pushdown需要注意。
关于LCT的时间复杂度,为 \(O(N\log{N})\),常数较大。目前还不会证明。毕竟Splay就已经很玄学了,这么多个Splay更玄学……
模板题
Code
#include
#include
using namespace std;
const int N = 1e5 + 5;
int n, m, stack[N], top;
struct Node {
int s[2], p, v, sum, rev;
} t[N];
bool isroot(int x) {
return t[t[x].p].s[0] != x && t[t[x].p].s[1] != x;
}
void pushup(int x) {
t[x].sum = t[t[x].s[0]].sum ^ t[t[x].s[1]].sum ^ t[x].v;
}
void pushrev(int x) {
swap(t[x].s[0], t[x].s[1]); t[x].rev ^= 1;
}
void pushdown(int x) {
if (!t[x].rev) return ;
pushrev(t[x].s[0]); pushrev(t[x].s[1]); t[x].rev = 0;
}
void rotate(int x) {
int y = t[x].p, z = t[y].p, k = t[y].s[1] == x;
if (!isroot(y)) t[z].s[t[z].s[1] == y] = x; t[x].p = z;
t[y].s[k] = t[x].s[k ^ 1]; t[t[x].s[k ^ 1]].p = y;
t[x].s[k ^ 1] = y; t[y].p = x;
pushup(y); pushup(x);
}
void splay(int x) {
int p = x;
stack[++top] = p;
while (!isroot(p)) stack[++top] = p = t[p].p;
while (top) pushdown(stack[top--]);
while (!isroot(x)) {
int y = t[x].p, z = t[y].p;
if (!isroot(y))
if (t[z].s[1] == y ^ t[y].s[1] == x) rotate(x);
else rotate(y);
rotate(x);
}
}
void access(int x) {
int z = x;
for (int y = 0; x; y = x, x = t[x].p) {
splay(x); t[x].s[1] = y; pushup(x);
}
splay(z);
}
void makeroot(int x) {
access(x); pushrev(x);
}
int findroot(int x) {
access(x);
while (t[x].s[0]) {
pushdown(x); x = t[x].s[0];
}
splay(x);
return x;
}
void split(int x, int y) {
makeroot(x); access(y);
}
void link(int x, int y) {
makeroot(x);
if (findroot(y) != x) t[x].p = y;
}
void cut(int x, int y) {
makeroot(x);
if (findroot(y) == x && t[x].s[1] == y && !t[y].s[0]) {
t[x].s[1] = t[y].p = 0; pushup(x);
}
}
int main() {
scanf("%d%d", &n, &m);
for (int i = 1; i <= n; i++) scanf("%d", &t[i].v);
while (m--) {
int opt, x, y;
scanf("%d%d%d", &opt, &x, &y);
if (opt == 0) {
split(x, y); printf("%d\n", t[y].sum);
}
else if (opt == 1) link(x, y);
else if (opt == 2) cut(x, y);
else {
splay(x); t[x].v = y; pushup(x);
}
}
return 0;
}
例:Acwing999 魔法森林
有一张 \(n\) 个节点 \(m\) 条边的无向图,每条边有两个权值 \(a_i,b_i\)。要求找到一条从 \(1\) 到 \(n\) 的路径,要求路径上 \(a\) 的最大值与 \(b\) 的最大值的和最小。\(1≤n≤50000\),\(0≤m≤100000\)。
首先,若 \(a\) 的最大值已经确定(设为 \(\bar{a}\)),则无向图中可以走的边是确定的。问题转化为求解此图中从 \(1\) 到 \(n\) 的路径使最大值最小,以 \(b\) 为权值。可以想到一种类似最小生成树的做法。将边按 \(b\) 从小到大排序,不断加边,直到连通,并查集维护。
可以想到二分。但是,此题中虽然随着 \(\bar{a}\) 的增加图中边数增加,但同时最小生成树的值会减小,最终答案不具有单调性。那么对于所有 \(\bar{a}\),均需更新一次答案。边数为 \(100000\),那么就需要 \(log\) 级别的算法。
那么我们将所有边按 \(a\) 排序,则 \(\bar{a}\) 不断增大,不断向图中加边。我们需要维护 \(1\) 到 \(n\) 的最大值最小路径。不难想到维护最小生成树。加边时,用类似求次小生成树的方法,加入这条边,若构成环,删除环中权值最大的边。
那么,在最小生成树中如何 \(\log{N}\) 维护 \(1\) 到 \(n\) 最大值最小路径呢?可以想到LCT。利用点边转化的技巧,将边转化为点,边权转化为点权,原来点的权值为0,则可以实现。
此题需要卡常,在判断点连通时可以舍弃findroot,使用并查集。
Code
#include
#include
using namespace std;
const int N = 5e4 + 5, M = 1e5 + 5;
int n, m, v[N + M], fa[N];
int stack[N + M], top;
struct Edge {
int x, y, a, b;
bool operator <(const Edge &o) const {
return a < o.a;
}
} e[M];
struct Node {
int s[2], p, mx, rev;
} t[N + M];
int read() {
int x = 0; char c = getchar();
while (c < '0' || c > '9') c = getchar();
while (c >= '0' && c <= '9') {x = (x << 3) + (x << 1) + (c ^ 48); c = getchar();}
return x;
}
int get(int x) {
if (x == fa[x]) return x;
return fa[x] = get(fa[x]);
}
bool isroot(int x) {
return t[t[x].p].s[0] != x && t[t[x].p].s[1] != x;
}
void pushrev(int x) {
swap(t[x].s[0], t[x].s[1]); t[x].rev ^= 1;
}
void pushup(int x) {
t[x].mx = x;
for (int i = 0; i < 2; i++)
if (v[t[t[x].s[i]].mx] > v[t[x].mx])
t[x].mx = t[t[x].s[i]].mx;
}
void pushdown(int x) {
if (!t[x].rev) return ;
pushrev(t[x].s[0]); pushrev(t[x].s[1]); t[x].rev = 0;
}
void rotate(int x) {
int y = t[x].p, z = t[y].p, k = t[y].s[1] == x;
if (!isroot(y)) t[z].s[t[z].s[1] == y] = x; t[x].p = z;
t[y].s[k] = t[x].s[k ^ 1]; t[t[x].s[k ^ 1]].p = y;
t[x].s[k ^ 1] = y; t[y].p = x;
pushup(y); pushup(x);
}
void splay(int x) {
int p = x;
stack[++top] = p;
while (!isroot(p)) stack[++top] = p = t[p].p;
while (top) pushdown(stack[top--]);
while (!isroot(x)) {
int y = t[x].p, z = t[y].p;
if (!isroot(y))
if (t[z].s[1] == y ^ t[y].s[1] == x) rotate(x);
else rotate(y);
rotate(x);
}
}
void access(int x) {
int z = x;
for (int y = 0; x; y = x, x = t[x].p) {
splay(x); t[x].s[1] = y; pushup(x);
}
splay(z);
}
void makeroot(int x) {
access(x); pushrev(x);
}
int findroot(int x) {
access(x);
while (t[x].s[0]) {
pushdown(x); x = t[x].s[0];
}
splay(x);
return x;
}
void split(int x, int y) {
makeroot(x); access(y);
}
void link(int x, int y) {
makeroot(x);
if (findroot(y) != x) t[x].p = y;
}
void cut(int x, int y) {
makeroot(x);
if (findroot(y) == x && t[x].s[1] == y && !t[y].s[0]) {
t[x].s[1] = t[y].p = 0; pushup(x);
}
}
int main() {
int ans = 0x7fffffff;
n = read(); m = read();
for (int i = 1; i <= m; i++) e[i] = (Edge){read(), read(), read(), read()};
sort(e + 1, e + m + 1);
for (int i = 1; i <= n + m; i++) {
t[i].mx = i;
if (i <= n) fa[i] = i;
else v[i] = e[i - n].b;
}
for (int i = 1; i <= m; i++) {
int x = e[i].x, y = e[i].y;
if (get(x) == get(y)) {
split(x, y);
if (v[t[y].mx] > e[i].b) {
int p = t[y].mx;
cut(e[p - n].x, p); cut(e[p - n].y, p);
link(x, i + n); link(y, i + n);
}
}
else {
link(x, i + n); link(y, i + n); fa[get(x)] = get(y);
}
if (get(1) == get(n)) {
split(1, n); ans = min(ans, e[i].a + v[t[n].mx]);
}
}
printf("%d\n", ans == 0x7fffffff ? -1 : ans);
return 0;
}