树上启发式合并(dsu on tree)精巧的暴力


树上启发式合并(dsu on tree)

虽然叫dsu但这和并查集貌似没什么关系

例:

给你一棵树,每个节点有一个颜色,要求出每个子树中数量最多的颜色并输出

(数量相同的情况先不考虑 不重要

当我们需要在每个子树上统计一些信息的时候,往往会开一个全局的cnt数组,试图 dfs \(O(n)\) 扫一遍,一边加点一边得到答案

但对于一棵树而言显然有问题:当我们统计完其左子树的信息后,必须清空整个cnt数组才能去扫右子树,这样其实就已经变成 \(O(n^2)\)

当然我们可以稍微偷工减料一点,因为最后一棵子树统计完后不用清空,我们可以最后遍历最大的那棵子树

最大子树可以通过一遍dfs预处理出子树的size,记录每个点的重儿子得到(类似树剖)

然而就是这一点偷工减料,使得整个算法复杂度直接降至 \(O(nlogn)\)

如果不关心证明的话,你已经学会 dsu on tree 了

证明为什么这样瞎搞就能获得\(nlogn\)?的复杂度

以下通过感性理解的方式说明为什么这东西能优化这么多

回顾一下在每个节点处我们要做什么:

  • dfs轻儿子,并消除影响
  • dfs重儿子,不消除影响
  • 再统计轻子树的影响

前两步的操作一共是 \(O(n)\) 的,就是最朴素的从头到尾扫一遍

现在需要考虑:在每个点处对每个轻子树扫一遍的复杂度

如果一个点和根节点之间一共有 x 条轻边,那么它会被遍历差不多 x+1 次

而轻重链剖分有个很好的性质:走一条轻边时,节点数量至少被砍一半,否则这就不是轻边了

那么从根节点到任意节点经过的轻边数量最多是 \(logn\)? 级别的

所以其实很显然了:复杂度就是 \(O(nlogn)\)

再看看极端情况加深理解:

树上问题最容易被出题人的各种链,菊花图,链加菊花图啥的卡掉

如果这棵树长得像链,它将被最后走最大子树这一小贪心优化掉一大半;

如果这棵树长得像菊花图,,那么根节点到任意节点间的轻边数量都将是极少的;

所以你可以相信dsu on tree

代码(这道题的)

int n;
int col[maxn];
int cnt[maxn];
ll ans[maxn];
int siz[maxn], son[maxn];

struct Edge{
    int t, nt;
}e[maxn*2];

int hd[maxn], ecnt = 0;

inline void add(int x, int y){
    e[++ecnt].t = y;
    e[ecnt].nt = hd[x];
    hd[x] = ecnt;
}

void dfs1(int p, int fa){
    siz[p] = 1;
    son[p] = 0;
    for(int i=hd[p];i;i=e[i].nt){
        int v = e[i].t;
        if(v!=fa){
            dfs1(v, p);
            siz[p] += siz[v];
            if(siz[v] > siz[son[p]]) son[p] = v;
        }
    }
}

ll tot = 0, mxc = 0;
void addcol(int c, int ad){//只计加不计减(减肯定减到0)
    cnt[c] += ad;
    if(cnt[c] > mxc){
        mxc = cnt[c];
        tot = c;
    }else if(cnt[c] == mxc){
        tot += c;
    }
}

void cntall(int p, int fa, int d){
    for(int i=hd[p];i;i=e[i].nt){
        int v = e[i].t;
        if(v!=fa){
            cntall(v, p, d);
        }
    }
    addcol(col[p], d);
}

void dfs(int p, int fa, int sav){
    for(int i=hd[p];i;i=e[i].nt){
        int v = e[i].t;
        if(v!=fa && v!=son[p]){
            dfs(v, p, 0);
        }
    }
    if(son[p]) dfs(son[p], p, 1);
    for(int i=hd[p];i;i=e[i].nt){
        int v = e[i].t;
        if(v!=fa && v!=son[p]){
            cntall(v, p, 1);
        }
    }
    //此时所有子节点均已记录
    addcol(col[p], 1);
    ans[p] = tot;
    if(!sav) cntall(p, fa, -1), tot = mxc = 0;
}

void solve(){
    cin >> n;
    for(int i=1;i<=n;i++) cin >> col[i];
    for(int i=1;i> x >> y;
        add(x, y); add(y, x);
    }
    dfs1(1, -1);
    dfs(1, -1, 1);
    for(int i=1;i