虚树学习笔记


虚树定义

虚树是一棵虚拟构建的树,这棵树的特点是只包含关键点以及关键的点,这些点满足在原树中的关系,而其他点和边相当于都做了路径压缩

例题引入

luogu2495 [SDOI2011]消耗战

题目描述

一颗树,上面有 k 个资源点,拆一些边,使得 1 号点不能到达任何资源点。现在要使得拆除的边的权值和最小。总共有 m 次询问,每次给出资源点。

题解

题目中的资源点在虚树中就相当于关键点

当我们不考虑虚树时,我们考虑怎么做这个题

我们发现可以树形DP

\(dp[i]\)表示以\(i\)为根的子树中不与关键点连通的最小代价,\(u\)\(i\)的儿子

则有DP方程

\[dp[i]= \begin{cases} dp[i]+min(dp[u],e[i][u])~~~~u不是关键点\\ dp[i]=dp[i]+e[i][u]~~~~~~~~~~~~u是关键点 \end{cases} \]

很显然,这样的复杂度是\(O(nm)\)的,并不符合我们的要求

我们考虑有没有更优的做法或优化

我们重新观察题面

我们发现关键点的总数量只有\(n\)

又发现我们的算法中其实有很多点的子树中并不包含关键点,也就是说根本不需要算它的\(dp\)

因此,我们知道肯定要做出一棵很小的树来快速解决问题

而这棵树中需要储存的东西只有关键点和他们的\(LCA\),也就是一棵虚树

这棵虚树怎么去建呢(图片来自 oi-wiki.org)

对于一个这样的图

1

红色是关键点,红点和黑点都是虚树中的点,黑边是虚树中的边

1

1

1

通过这几张图,我们具象化的了解了虚树的形状,接下来考虑如何建一棵虚树

我们不能\(O(n^2)\)\(LCA\),不难想到可以按照\(dfs\)序排序后求相邻的\(LCA\)

我们知道,对于一棵虚树,只要保证祖先后代关系不变即可随便加点

因此为了方便我们把根节点加进去

然后,我们来做出一个方案建立虚树

我们开一个单调栈,维护虚树上的一条链

如果当前我们要加进去的节点\(now\),与栈顶节点\(top\)\(LCA\)\(top\)就直接入栈

如果不是,则弹栈直到与\(top\)\(LCA\)\(top\)时将\(now\)入栈

当然,在这个过程中不要忘了将栈顶与弹出的节点连边

当我们把全部过程做完后,虚树也就建好了

这时候,我们重新回到那个题

我们处理出树上每个点到根的路径上的最小值,然后直接按照我们原来的方式\(dp\)就行了

#include 
#include 
#include 
#include 
#include 
#include 
#define int long long
#define file(a) freopen(#a".in","r",stdin),freopen(#a".out","w",stdout)
using namespace std;
const int maxn=5e5+5;
int n,N,m,beg[maxn],tot,Min[maxn],fa[maxn][26],dp[maxn],dfn[maxn],cnt,a[maxn],vis[maxn],st[maxn],top,dep[maxn];
struct edge{
    int nex,to,w;
}e[maxn*2];
void add(int x,int y,int z) {
    e[++tot]=(edge){beg[x],y,z};
    beg[x]=tot;
}
vectorvec[maxn*2];
void chkmax(int &x,int y) {if (xy) x=y;}
int read() {
    int x=0,f=1;
    char ch=getchar();
    while(ch<'0' || ch>'9') {if (ch=='-') f=-1;ch=getchar();}
    while(ch<='9' && ch>='0') {x=(x<<3)+(x<<1)+ch-'0';ch=getchar();}
    return x*f;
}
void dfs(int now,int Fa) {
    fa[now][0]=Fa;
    dfn[now]=++cnt;
    dep[now]=dep[Fa]+1;
    for (int i=beg[now];i;i=e[i].nex) {
        int nex=e[i].to;
        if (nex==Fa) continue;
        chkmin(Min[nex],e[i].w);
        chkmin(Min[nex],Min[now]);
        dfs(nex,now);
    }
}
bool cmp(int x,int y) {
    return dfn[x]=0;i--) {
        int fx=fa[x][i];
        if (dep[fx]>=dep[y]) x=fx;
        if (x==y) return x;
    } 
    for (int i=20;i>=0;i--) {
        if (fa[x][i]!=fa[y][i]) {
            x=fa[x][i];
            y=fa[y][i];
        }
    }
    return fa[x][0];
}
void build() {
    sort(a+1,a+1+N,cmp);
    st[top=1]=1;
    for (int i=1;i<=N;i++) {
        int lca=LCA(st[top],a[i]);
        if (lca!=st[top]) {
            while(dfn[lca]