树链剖分&动态树学习笔记


如何将树上的一段路径转化为区间问题?我们可能会想到树上莫队中利用欧拉序性质的做法,但其不具有普适性,对于很多区间问题,难以将出现两次的元素减掉。而树链剖分与动态树都可以很好地解决这类问题。

树链剖分

树链剖分也称为重链剖分,适用于形态结构不发生变化的树(即静态)。
将树上所有边分为重边轻边,每个节点与其子节点的所有连边中,有且仅有一条重边(叶子结点除外),其余为轻边。重边的求法:对于每个节点,考察其所有子节点,子树中包含节点最多的子节点称为“最重”。重边连接此节点与其“最重”的子节点。若有多个子节点同为“最重”,任选其一。
连在一起的重边从上而下形成重链。所有节点都包含在重链中,若一个节点上下均无重边,则其独自构成一个重链。此时有性质:从根节点出发的任意一条路径“穿过”的重链数不超过 \(\log{N}\)
证明:每离开一条重链,走进另一条重链,必然使得其子树大小变为原来的 \(\frac{1}{k}\),其中 \(k\) 表示当前节点的子节点个数,且必然 \(≥2\)。于是得证。
对此树求DFS序(注意不是欧拉序,每个节点仅出现一次)。DFS优先拓展重边。此时有性质:一条重链上的所有节点在DFS序中连在一起。

于是,树上的任意一条路径,一定可以被拆分为至多 \(\log{N}\) 个连续区间。我们对每个区间分别求解,最后加在一起,即为答案。
对于DFS序中的元素,常用线段树/树状数组/平衡树等 \(\log{N}\) 级别的数据结构维护。故树链剖分的时间复杂度为 \(O(N\log^2{N})\)。但由于路径穿过的重链往往远不及 \(\log{N}\) 个,且每个重链对应的区间长度也较小,故此算法常数很小,实际效率很高。

模板题
Code

#include
#include
#define ll long long
using namespace std;
const int N = 1e5 + 5;
int n, m, a[N];
int head[N], nxt[N << 1], ver[N << 1], tot;
int id[N], sz[N], idx, dfn[N], d[N], son[N], fa[N], top[N];
struct Tree {
    int l, r;
    ll sum, add;
} t[N << 2];
void Add(int x, int y) {
    nxt[++tot] = head[x]; head[x] = tot; ver[tot] = y;
}
void dfs1(int x) {
    sz[x] = 1;
    for (int i = head[x]; i; i = nxt[i]) {
        int y = ver[i];
        if (d[y]) continue;
        d[y] = d[x] + 1; fa[y] = x;
        dfs1(y);
        sz[x] += sz[y];
        if (sz[y] > sz[son[x]]) son[x] = y;
    }
}
void dfs2(int x) {
    dfn[++idx] = x; id[x] = idx;
    if (!son[x]) return ;
    top[son[x]] = top[x]; dfs2(son[x]);
    for (int i = head[x]; i; i = nxt[i]) {
        int y = ver[i];
        if (d[y] < d[x] || y == son[x]) continue;
        top[y] = y; dfs2(y);
    }
}
void pushup(int p) {
    t[p].sum = t[p << 1].sum + t[p << 1 | 1].sum;
}
void pushdown(int p) {
    if (!t[p].add) return ;
    Tree &u = t[p], &l = t[p << 1], &r = t[p << 1 | 1];
    l.add += u.add; l.sum += u.add * (l.r - l.l + 1);
    r.add += u.add; r.sum += u.add * (r.r - r.l + 1);
    u.add = 0;
}
void Build(int l, int r, int p) {
    t[p].l = l; t[p].r = r;
    if (l == r) {
        t[p].sum = a[dfn[l]]; return ;
    }
    int mid = l + r >> 1;
    Build(l, mid, p << 1); Build(mid + 1, r, p << 1 | 1);
    pushup(p);
}
void Insert(int l, int r, int v, int p) {
    if (l <= t[p].l && t[p].r <= r) {
        t[p].add += v; t[p].sum += (ll)v * (t[p].r - t[p].l + 1); return ;
    }
    int mid = t[p].l + t[p].r >> 1;
    pushdown(p);
    if (l <= mid) Insert(l, r, v, p << 1);
    if (r > mid) Insert(l, r, v, p << 1 | 1);
    pushup(p);
}
ll Query(int l, int r, int p) {
    if (l <= t[p].l && t[p].r <= r) return t[p].sum;
    int mid = t[p].l + t[p].r >> 1;
    ll res = 0;
    pushdown(p);
    if (l <= mid) res += Query(l, r, p << 1);
    if (r > mid) res += Query(l, r, p << 1 | 1);
    pushup(p);
    return res;
}
int main() {
    scanf("%d", &n);
    for (int i = 1; i <= n; i++) scanf("%d", &a[i]);
    for (int i = 1; i < n; i++) {
        int x, y;
        scanf("%d%d", &x, &y);
        Add(x, y); Add(y, x);
    }
    d[1] = 1; dfs1(1); top[1] = 1; dfs2(1);
    Build(1, n, 1);
    scanf("%d", &m);
    while (m--) {
        int opt, x, y, z;
        scanf("%d%d", &opt, &x);
        if (opt == 1) {
            scanf("%d%d", &y, &z);
            while (top[x] != top[y]) { //树剖部分
                if (d[top[x]] < d[top[y]]) swap(x, y);
                Insert(id[top[x]], id[x], z, 1);
                x = fa[top[x]];
            }
            if (d[x] < d[y]) swap(x, y);
            Insert(id[y], id[x], z, 1); //剩下一段单独处理
        }
        else if (opt == 2) {
            scanf("%d", &z);
            Insert(id[x], id[x] + sz[x] - 1, z, 1);
        }
        else if (opt == 3) {
            scanf("%d", &y);
            ll res = 0;
            while (top[x] != top[y]) {
                if (d[top[x]] < d[top[y]]) swap(x, y);
                res += Query(id[top[x]], id[x], 1);
                x = fa[top[x]];
            }
            if (d[x] < d[y]) swap(x, y);
            res += Query(id[y], id[x], 1);
            printf("%lld\n", res);
        }
        else printf("%lld\n", Query(id[x], id[x] + sz[x] - 1, 1));
    }
    return 0;
}

对于静态的树来说,树剖还是很不错的算法。

动态树

这就是大名鼎鼎的LCT(Link-Cut Tree)。支持动态地加边、删边,并兼容树链剖分的几乎所有操作。时间复杂度仅为 \(O(N\log{N})\)
实际上LCT维护的是森林,以下以一棵树进行讨论。
与树链剖分类似地,将所有边分别实边虚边,每个节点可以向下连最多一条实边,但也可以不连。实边构成实边链,每个节点包含在一条链中。
接下来就是LCT的关键。对每条实边链,用一个Splay维护,Splay的中序遍历对应链中从上往下的顺序。实边链中最上面的节点不一定是Splay的根节点。
然后,将这些Splay连接起来,构成一棵树,以下称为Splay的树。对于每条实边链,连接其最上面节点(设为 \(x\) )与其父节点(设为 \(y\) )的虚边在Splay的树中连接这条实边链对应Splay的根节点与 \(y\)(连接方法之后讨论)。而如何通过Splay的树对应到原树?实边显然可以通过每个Splay的中序遍历得到,虚边则通过每个Splay中最“左”的点与其根节点向上连到的点得到。
接着讨论如何连接Splay的树。在以下的操作中,要求每个Splay要隔离,但又要求向上追溯到原树的根(为什么?1.LCT维护的实际是森林;2.LCT是无根树,之后的操作会进行“换根”)。我们的处理方法是:每个Splay的根节点向上的连边仅有“向上连”而没有“向下连”,即“认父不认子”。Splay和rotate中也要特殊判断,防止转错。
以下介绍LCT的几大操作。

access(int x)
作用是,在原树中构造出从根到 \(x\) 的实边路径,且不再包含 \(x\) 的子节点。并将 \(x\) splay到其Splay的根。
\(x\) splay到根节点(仅为其splay根节点),找到其父节点 \(y\),将 \(y\) splay到根节点。则此时 \(y\) 一定没有后继。将 \(x\) 为根的子树全部接到 \(y\) 的右儿子。不断循环该操作即可。
makeroot(int x)
\(x\) 变成原树中的根节点。
只需access(x),此时 \(x\) 到根的路径构成一棵splay,将其翻转即可。
findroot(int x)
找到 \(x\) 在原树中的根节点。附属作用,将根节点转到splay中的根。
只需access(x),再一路向左即可。最后记得splay。
split(int x, int y)
\(x,y\) 在原树中连通,则建立一条 \(x\)\(y\) 的实边路径。
makeroot(x),access(y)。
link(int x, int y)
\(x,y\) 不连通,加入边 \((x,y)\)
先makeroot(x)。若findroot(y)!=x,则连接。因为不确定此时 \(x\) 是否有向下的实边,故只能连虚边。因为已经makeroot(x)了,故可以直接将 \(x\) 的父亲设为 \(y\)
cut(int x, int y)
\(x,y\) 之间有边,删除此边。
先makeroot(x)。若 \(x,y\) 有边,则此时 findroot(y)==x。且因为前面的findroot(y),\(y\)\(x\) 一定在同一条实边链上,则 \(y\)\(x\) 的后继。且根据rotate的操作,\(y\) 一定为 \(x\) 的右子节点。于是:

void cut(int x, int y) {
    makeroot(x);
    if (findroot(y) == x && t[x].s[1] == y && !t[y].s[0]) {
        t[x].s[1] = t[y].p = 0; pushup(x);
    }
}

至此为LCT的所有操作。第一次写代码时可能很难,但背模板还是可以的。LCT与网络流有些相似,不会在代码的实现上有很多变通,基本只有pushup和pushdown需要注意。
关于LCT的时间复杂度,为 \(O(N\log{N})\),常数较大。目前还不会证明。毕竟Splay就已经很玄学了,这么多个Splay更玄学……

模板题
Code

#include
#include
using namespace std;
const int N = 1e5 + 5;
int n, m, stack[N], top;
struct Node {
    int s[2], p, v, sum, rev;
} t[N];
bool isroot(int x) {
    return t[t[x].p].s[0] != x && t[t[x].p].s[1] != x;
}
void pushup(int x) {
    t[x].sum = t[t[x].s[0]].sum ^ t[t[x].s[1]].sum ^ t[x].v;
}
void pushrev(int x) {
    swap(t[x].s[0], t[x].s[1]); t[x].rev ^= 1;
}
void pushdown(int x) {
    if (!t[x].rev) return ;
    pushrev(t[x].s[0]); pushrev(t[x].s[1]); t[x].rev = 0;
}
void rotate(int x) {
    int y = t[x].p, z = t[y].p, k = t[y].s[1] == x;
    if (!isroot(y)) t[z].s[t[z].s[1] == y] = x; t[x].p = z;
    t[y].s[k] = t[x].s[k ^ 1]; t[t[x].s[k ^ 1]].p = y;
    t[x].s[k ^ 1] = y; t[y].p = x;
    pushup(y); pushup(x);
}
void splay(int x) {
    int p = x;
    stack[++top] = p;
    while (!isroot(p)) stack[++top] = p = t[p].p;
    while (top) pushdown(stack[top--]);
    while (!isroot(x)) {
        int y = t[x].p, z = t[y].p;
        if (!isroot(y))
            if (t[z].s[1] == y ^ t[y].s[1] == x) rotate(x);
            else rotate(y);
        rotate(x);
    }
}
void access(int x) {
    int z = x;
    for (int y = 0; x; y = x, x = t[x].p) {
        splay(x); t[x].s[1] = y; pushup(x);
    }
    splay(z);
}
void makeroot(int x) {
    access(x); pushrev(x);
}
int findroot(int x) {
    access(x);
    while (t[x].s[0]) {
        pushdown(x); x = t[x].s[0];
    }
    splay(x);
    return x;
}
void split(int x, int y) {
    makeroot(x); access(y);
}
void link(int x, int y) {
    makeroot(x);
    if (findroot(y) != x) t[x].p = y;
}
void cut(int x, int y) {
    makeroot(x);
    if (findroot(y) == x && t[x].s[1] == y && !t[y].s[0]) {
        t[x].s[1] = t[y].p = 0; pushup(x);
    }
}
int main() {
    scanf("%d%d", &n, &m);
    for (int i = 1; i <= n; i++) scanf("%d", &t[i].v);
    while (m--) {
        int opt, x, y;
        scanf("%d%d%d", &opt, &x, &y);
        if (opt == 0) {
            split(x, y); printf("%d\n", t[y].sum);
        }
        else if (opt == 1) link(x, y);
        else if (opt == 2) cut(x, y);
        else {
            splay(x); t[x].v = y; pushup(x);
        }
    }
    return 0;
}

例:Acwing999 魔法森林
有一张 \(n\) 个节点 \(m\) 条边的无向图,每条边有两个权值 \(a_i,b_i\)。要求找到一条从 \(1\)\(n\) 的路径,要求路径上 \(a\) 的最大值与 \(b\) 的最大值的和最小。\(1≤n≤50000\)\(0≤m≤100000\)

首先,若 \(a\) 的最大值已经确定(设为 \(\bar{a}\)),则无向图中可以走的边是确定的。问题转化为求解此图中从 \(1\)\(n\) 的路径使最大值最小,以 \(b\) 为权值。可以想到一种类似最小生成树的做法。将边按 \(b\) 从小到大排序,不断加边,直到连通,并查集维护。
可以想到二分。但是,此题中虽然随着 \(\bar{a}\) 的增加图中边数增加,但同时最小生成树的值会减小,最终答案不具有单调性。那么对于所有 \(\bar{a}\),均需更新一次答案。边数为 \(100000\),那么就需要 \(log\) 级别的算法。
那么我们将所有边按 \(a\) 排序,则 \(\bar{a}\) 不断增大,不断向图中加边。我们需要维护 \(1\)\(n\) 的最大值最小路径。不难想到维护最小生成树。加边时,用类似求次小生成树的方法,加入这条边,若构成环,删除环中权值最大的边。
那么,在最小生成树中如何 \(\log{N}\) 维护 \(1\)\(n\) 最大值最小路径呢?可以想到LCT。利用点边转化的技巧,将边转化为点,边权转化为点权,原来点的权值为0,则可以实现。
此题需要卡常,在判断点连通时可以舍弃findroot,使用并查集。

Code

#include
#include
using namespace std;
const int N = 5e4 + 5, M = 1e5 + 5;
int n, m, v[N + M], fa[N];
int stack[N + M], top;
struct Edge {
    int x, y, a, b;
    bool operator <(const Edge &o) const {
        return a < o.a;
    }
} e[M];
struct Node {
    int s[2], p, mx, rev;
} t[N + M];
int read() {
    int x = 0; char c = getchar();
    while (c < '0' || c > '9') c = getchar();
    while (c >= '0' && c <= '9') {x = (x << 3) + (x << 1) + (c ^ 48); c = getchar();}
    return x;
}
int get(int x) {
    if (x == fa[x]) return x;
    return fa[x] = get(fa[x]);
}
bool isroot(int x) {
    return t[t[x].p].s[0] != x && t[t[x].p].s[1] != x;
}
void pushrev(int x) {
    swap(t[x].s[0], t[x].s[1]); t[x].rev ^= 1;
}
void pushup(int x) {
    t[x].mx = x;
    for (int i = 0; i < 2; i++)
        if (v[t[t[x].s[i]].mx] > v[t[x].mx])
            t[x].mx = t[t[x].s[i]].mx;
}
void pushdown(int x) {
    if (!t[x].rev) return ;
    pushrev(t[x].s[0]); pushrev(t[x].s[1]); t[x].rev = 0;
}
void rotate(int x) {
    int y = t[x].p, z = t[y].p, k = t[y].s[1] == x;
    if (!isroot(y)) t[z].s[t[z].s[1] == y] = x; t[x].p = z;
    t[y].s[k] = t[x].s[k ^ 1]; t[t[x].s[k ^ 1]].p = y;
    t[x].s[k ^ 1] = y; t[y].p = x;
    pushup(y); pushup(x);
}
void splay(int x) {
    int p = x;
    stack[++top] = p;
    while (!isroot(p)) stack[++top] = p = t[p].p;
    while (top) pushdown(stack[top--]);
    while (!isroot(x)) {
        int y = t[x].p, z = t[y].p;
        if (!isroot(y))
            if (t[z].s[1] == y ^ t[y].s[1] == x) rotate(x);
            else rotate(y);
        rotate(x);
    }
}
void access(int x) {
    int z = x;
    for (int y = 0; x; y = x, x = t[x].p) {
        splay(x); t[x].s[1] = y; pushup(x);
    }
    splay(z);
}
void makeroot(int x) {
    access(x); pushrev(x);
}
int findroot(int x) {
    access(x);
    while (t[x].s[0]) {
        pushdown(x); x = t[x].s[0];
    }
    splay(x);
    return x;
}
void split(int x, int y) {
    makeroot(x); access(y);
}
void link(int x, int y) {
    makeroot(x);
    if (findroot(y) != x) t[x].p = y;
}
void cut(int x, int y) {
    makeroot(x);
    if (findroot(y) == x && t[x].s[1] == y && !t[y].s[0]) {
        t[x].s[1] = t[y].p = 0; pushup(x);
    }
}
int main() {
    int ans = 0x7fffffff;
    n = read(); m = read();
    for (int i = 1; i <= m; i++) e[i] = (Edge){read(), read(), read(), read()};
    sort(e + 1, e + m + 1);
    for (int i = 1; i <= n + m; i++) {
        t[i].mx = i;
        if (i <= n) fa[i] = i;
        else v[i] = e[i - n].b;
    }
    for (int i = 1; i <= m; i++) {
        int x = e[i].x, y = e[i].y;
        if (get(x) == get(y)) {
            split(x, y);
            if (v[t[y].mx] > e[i].b) {
                int p = t[y].mx;
                cut(e[p - n].x, p); cut(e[p - n].y, p);
                link(x, i + n); link(y, i + n);
            }
        }
        else {
            link(x, i + n); link(y, i + n); fa[get(x)] = get(y);
        }
        if (get(1) == get(n)) {
            split(1, n); ans = min(ans, e[i].a + v[t[n].mx]);
        }
    }
    printf("%d\n", ans == 0x7fffffff ? -1 : ans);
    return 0;
}