虚树学习笔记


虚树学习笔记

问题的引入

在树上 DP 的问题中,可能有多次询问,每次询问包括的总点数规模较小(例如 \(10^5\))。我们记节点数为 \(n\),询问次数为 \(m\),询问中总点数为 \(\sum k\),那么直接在整棵树上暴力 DP 的复杂度为 \(\mathcal{O}(nm)\),不可接受。能不能发明一种 DP 的方法,不需要访问所有节点,只访问 \(\mathcal{O}(k)\) 个节点呢?这样时间复杂度就优化成了 \(\mathcal{O}(m+\sum k)\)

这种方法当然是有的,就是虚树。

虚树的概念

我们称单次询问中涉及到的节点为关键节点,它们两两的 LCA 为关键 LCA。

虚树是我们构建的一棵外向树,它只包含所有关键节点、关键 LCA 和树根,其它的不重要的节点都被压没了,因为它们对答案没有贡献,递归地计算它们只会浪费时间。

举个例子,对于下面这棵树:

如果关键节点为 \(\{2,4,8,9\}\),则虚树(外向树)长这个样子:

容易发现虚树的点数不超过 \(2k\)(证明考虑每次至少合并两个点,直到合并到根)。

虚树的构建

我们需要预处理出整棵树的 dfs 序时间戳,节点 \(u\) 的时间戳记为 \({dfn}_u\)

我们使用一个栈来暂存树链,栈底为根,栈顶为当前枚举到的树链底端。由于需要访问次顶端(也就是顶端下面的元素),STL 的 std::stack 不那么方便,我们使用手写栈。\(stk\) 为栈,从下标 \(1\) 开始存,\(top\) 为栈顶指针,指向栈顶元素(而不是栈顶元素的后一个)。

我们先对所有关键点按照 \(dfn\) 升序排序,然后把它们依次插入虚树。下面考虑怎么把一个点插入虚树:

  1. 如果栈中只有一个元素即根节点,我们延长这条树链,即 stk[++top] = u;
  2. \(lca\)\(u\)\({stk}_{top}\) 的 LCA,如果 \(lca={stk}_{top}\),就意味着 \(u\)\({stk}_{top}\) 的后代,我们延长这条树链,即 stk[++top] = u;
  3. 如果 \(lca\ne{stk}_{top}\),就意味着 \(u\)\({stk}_{top}\) 属于它们 LCA 的两棵子树,并且栈中这棵子树已经构建完毕,我们需要把 LCA 包含的栈中树链退栈并完成虚树建边,为了虚树结构的完整性,如果 LCA 不在栈中则需要压栈,然后把当前节点压栈,延长树链。

这部分代码如下:

void insert(ll u) {
    if(top == 1) {stk[++top] = u; return;}
    ll lca = LCA(u, stk[top]);
    if(lca == stk[top]) {stk[++top] = u; return;}
    while(top > 1 && dfn[lca] <= dfn[stk[top-1]]) {
        vg[stk[top-1]].push_back(stk[top]);
        --top;
    }
    if(lca != stk[top]) {
        vg[lca].push_back(stk[top]);
        stk[top] = lca;
    }
    stk[++top] = u;
}

例题 洛谷 P2495 [SDOI2011]消耗战

题意

给定一棵 \(n\) 点的有边权树,\(m\) 次询问,每次给定 \(k\) 个点,查询要使得这 \(k\) 个点均不与 \(1\) 连通需要切断的边的最小边权和。

题解

\({mn}_u\) 表示节点 \(u\) 到根的路径的最小边权,即 \({mn}_u=\min\limits_{\{i:\textrm{path from }u\textrm{ to }1\}}w_i\)

\({dp}_u\) 表示 \(u\) 子树内的关键点不与 \(u\) 连通需要切断的最小边权和,显然答案为 \({dp}_1\)

容易得到转移方程:

\[ {dp}_u=\sum\limits_{\{v:\textrm{son of }u\}} \begin{cases} {mn}_v,&v\textrm{ is key vertex}\\ \min({mn}_v,{dp}_v)&\textrm{otherwise}\\ \end{cases} \]

发现转移只与是否是关键节点有关,每次询问建出虚树然后 DP 即可。

注意虚树清空时不能只清关键节点的出边,因为还有关键 LCA,我一开始就是没清完结果造出了重复边,可以再 dfs 一遍清空。

代码如下:

//By: Luogu@rui_er(122461)
#include 
#define rep(x,y,z) for(ll x=y;x<=z;x++)
#define per(x,y,z) for(ll x=y;x>=z;x--)
#define debug printf("Running %s on line %d...\n",__FUNCTION__,__LINE__)
#define fileIO(s) do{freopen(s".in","r",stdin);freopen(s".out","w",stdout);}while(false)
using namespace std;
typedef long long ll;
const ll N = 2.5e5+5; 

ll n, m, k, h[N], dfn[N], tms, fa[N][20], dis[N], mn[N], stk[N], top, tag[N];
vector > e[N];
vector vg[N];
template void chkmin(T& x, T y) {if(x > y) x = y;}
template void chkmax(T& x, T y) {if(x < y) x = y;}
void dfs1(ll u, ll f) {
    dfn[u] = ++tms;
    dis[u] = dis[f] + 1;
    fa[u][0] = f;
    rep(i, 1, 19) fa[u][i] = fa[fa[u][i-1]][i-1];
    for(auto i : e[u]) {
        ll v = get<0>(i), w = get<1>(i);
        if(v == f) continue;
        mn[v] = w;
        if(u != 1) chkmin(mn[v], mn[u]);
        dfs1(v, u);
    }
}
ll LCA(ll u, ll v) {
    if(dis[u] < dis[v]) swap(u, v);
    per(i, 19, 0) {
        if(dis[fa[u][i]] >= dis[v]) {
            u = fa[u][i];
        }
    }
    if(u == v) return u;
    per(i, 19, 0) {
        if(fa[u][i] != fa[v][i]) {
            u = fa[u][i];
            v = fa[v][i];
        }
    }
    return fa[u][0];
}
void insert(ll u) {
    if(top == 1) {stk[++top] = u; return;}
    ll lca = LCA(u, stk[top]);
    if(lca == stk[top]) {stk[++top] = u; return;}
    while(top > 1 && dfn[lca] <= dfn[stk[top-1]]) {
        vg[stk[top-1]].push_back(stk[top]);
        --top;
    }
    if(lca != stk[top]) {
        vg[lca].push_back(stk[top]);
        stk[top] = lca;
    }
    stk[++top] = u;
}
ll dfs2(ll u) {
    if(tag[u]) return mn[u];
    ll cost = 0;
    for(auto v : vg[u]) cost += min(mn[v], dfs2(v));
    return cost;
}
void dfsClear(ll u) {
    for(auto v : vg[u]) dfsClear(v);
    vg[u].clear();
}

int main() {
    scanf("%lld", &n);
    rep(i, 1, n-1) {
        ll u, v, w;
        scanf("%lld%lld%lld", &u, &v, &w);
        e[u].push_back(make_tuple(v, w));
        e[v].push_back(make_tuple(u, w));
    }
    dfs1(1, 0);
    for(scanf("%lld", &m);m;m--) {
        scanf("%lld", &k);
        rep(i, 1, k) {
            scanf("%lld", &h[i]);
            tag[h[i]] = 1;
        }
        sort(h+1, h+1+k, [&](ll a, ll b) {
            return dfn[a] < dfn[b];
        });
        stk[top=1] = 1;
        rep(i, 1, k) insert(h[i]);
        while(top > 1) {
            vg[stk[top-1]].push_back(stk[top]);
            --top;
        }
//      for(auto v : vg[1]) printf("1 -> %lld\n", v);
//      rep(u, 1, k) for(auto v : vg[h[u]]) printf("%lld -> %lld\n", h[u], v);
        printf("%lld\n", dfs2(1));
        rep(i, 1, k) tag[h[i]] = 0;
        dfsClear(1);
    }
    return 0;
}