点分治


点分治

核心思想:通过对路径上的点进行分类,分成两类查询,适用于大规模进行路径操作

我们考虑对于一棵树来说,可以把边分成两类,一类是过根节点的边,一类是不过根节点的边

我们考虑第一类比较容易,无论是统计答案还是找到这条边都很容易

但是第二类捏?

我们发现这些不过根节点的边必然是是属于其中的一个子树的,所以这给了我们很大的启示,我们可以递归处理这些边,仿照上面的分类方法,直到全部解决

但是我们挑选根节点是很有讲究的,比如说有一条链,我们挑选边缘上的一个点复杂度就是 \(O(n)\) 但是挑选中间的点的话,我们的复杂度就是 \(O(\log n)\)

此时我们引入树的重心的概念

树的重心也叫树的质心。找到一个点,其所有的子树中最大的子树节点数最少,那么这个点就是这棵树的重心,删去重心后,生成的多棵树尽可能平衡。

我们根据上面的定义,可以很容易知道不会出现一棵子树的大小大于 \(\dfrac n 2\) 原因是如果出现了大于的,我们肯定会在这棵子树里选择一个点作为根节点,必然比原来优

那么如果每次递归的时候都以重心为根,那么一定会很有用,因为它保证了你每次递归都至多有 \(\dfrac n 2\) 个点,同时我们发现对于每棵子树,我们求出重心的时间复杂度也只有 \(siz_v\)

所以总计时间复杂度应该是 \(O(n \log n)\),原因是一个点的贡献次数应该为自己成为重心前的贡献次数,那么最底端 的深度最多为 \(\log n\),这个其实感性理解一下就好

我们用几道例题理解一下

P3806 【模板】点分治1

题目描述:给定一棵有 \(n\) 个点的树,询问树上距离为 \(k\) 的点对是否存在。

\(1 \le n \le 10^4~~~1\le k \le 10^7~~~1\le q \le 10^4\)

首先由于没有修改操作,所以我们直接离线下来

中间我们开一个桶,来存当前边的值,找到子树的中心后,然后在过根节点的边上进行操作,统计答案

我们先放出代码,然后好好的理解一下点分治的运转过程

int main() {
    read(n, m);
    rep(i, 1, n - 1) {
        int u, v, w;
        read(u, v, w);
        add(u, v, w);
        add(v, u, w);
    }
    rep(i, 1, m) read(q[i]);
    maxsiz[rt] = 2e9;
    sumsiz = n;
    CalcSize(1, 0);
    dfs(rt, 0);
    rep(i, 1, m) puts(ans[i] ? "AYE" : "NAY");
    return 0;
}

首先是主函数,我们把 \(query\) 数组离线下来,然后初始令 \(maxsiz\)\(\inf\),我们第一次找重心时,树是整棵树,共有 \(n\) 个节点,所以 \(sumsiz\)\(n\) 然后我们去找重心,同时开始分治


接下来我们应该先去看怎么找到重心

inline void CalcSize(int u, int f) {
    siz[u] = 1;
    maxsiz[u] = 0;
    for (int i = head[u]; i; i = e[i].nxt) {
        int v = e[i].to;
        if (v == f || vis[v])
            continue;
        CalcSize(v, u);
        maxsiz[u] = max(maxsiz[u], siz[v]);
        siz[u] += siz[v];
    }
    maxsiz[u] = max(maxsiz[u], sumsiz - siz[u]);
    if (maxsiz[u] < maxsiz[rt])
        rt = u;
}

也有人把函数名写成 \(get\)_\(zx\) 的,由于本校出现了类似名字的学长,我决定换一个函数名

这段函数的目的就是求 \(siz\),在递归过程中,找到以这个点为根的子树的大小,那么的当前点的 \(maxsiz\) 就是 \(\max(maxsize_u, sumsize-siz_u)\),原因是对于整棵子树来说,应该长成这样:

这样看就理解原理了,我们的 \(maxsiz_u\) 应该是橙色,绿色,紫色三个圈最大的那个,橙色和紫色的比较体现在

maxsiz[u] = max(maxsiz[u], siz[v]);

橙色和紫色的比较体现在

maxsiz[u] = max(maxsiz[u], sumsiz - siz[u]);

以上就是我们的求重心函数


inline void GetDis(int u, int f) {
    tmp.push_back(dis[u]);
    for (int i = head[u]; i; i = e[i].nxt) {
        int v = e[i].to, w = e[i].val;
        if (v == f || vis[v])
            continue;
        dis[v] = dis[u] + w;
        GetDis(v, u);
    }
}

这个函数是找到所有过 \(u\) 点的所有点到点 \(u\) 的距离,不需要太多解释


inline void Calc(int u, int f) {
    exist[0] = vis[u] = true;
    tag.push(0);
    for (int i = head[u]; i; i = e[i].nxt) {
        int v = e[i].to, w = e[i].val;
        if (v == f || vis[v])
            continue;
        dis[v] = w;
        GetDis(v, u);
        int sizz = tmp.size() - 1;
        rep(j, 0, sizz)
            rep(k, 1, m)
                if (q[k] >= tmp[j])
                    ans[k] |= exist[q[k] - tmp[j]];
        rep(j, 0, sizz) {
            if (tmp[j] < K) {
                tag.push(tmp[j]);
                exist[tmp[j]] = true;
            }
        }
        tmp.clear();
    }
    while (!tag.empty())
        exist[tag.front()] = false, tag.pop();
}

这个函数是这个题所特有的,也就是求解点分治以后我们还需要做什么。

这个题要求我们寻找是否有两点之间距离为 \(k\) 的数

我们考虑首先我们开个 \(bool\) 类型的桶 \(exist\),记录现在已经有距离为 \(dis_v\) 的数了

然后我们考虑,我们对于这棵树来说,距离为 \(0\) 的节点必然是有的(因为根节点到自己的距离必然为 \(0\)),接下来我们考虑,如果这条路径是来自两条链拼在一起怎么办

我们会想到,对于这棵树来说,我们不妨每求完一条链,就把这条链上的信息放到桶里,这样对于一个到根节点距离为 \(dis\) 的点,我们会考虑枚举每一个 \(query\) 然后去查找 \(query - dis\) 是否存在

这里实现的几个细节:
首先是我们应该先等到每一个点查找 \(query\) 后,再把这个点放到 \(exist\) 里,原因很简单,要不然你这玩意就提前进去了,会导致判断 \(exist\) 时出问题

第二个就是当 \(tmp>K\) 时就憋往 \(exist\) 里面塞了, RE了又得骂评测姬了

第三个就是我们要在扫描完一棵子树后清空 \(tmp\)

第四个就是我们我们扫描完没棵子树的所有节点后,要用一个队列把每一个 \(tmp\) 存下来,然后可以直接根据这个清空 \(exist\),要不然每次光 \(10^7\) 大小的 \(exist\) 数组的 memset 都肯定受不了


最后就是人们所说的点分治的精髓,也就是 \(solve\) 函数,这里我写成 \(dfs\)

inline void dfs(int u, int f) {
    Calc(u, f);
    for (int i = head[u]; i; i = e[i].nxt) {
        int v = e[i].to;
        if (v == f || vis[v])
            continue;
        sumsiz = siz[v];
        maxsiz[rt = 0] = 2e9;
        CalcSize(v, u), dfs(rt, 0);
    }
}

过程:

统计以 \(u\) 为根的答案 -> 枚举每一个子树,对每一棵子树求重心 -> 对每一棵子树进行 \(dfs\)

这就是我们的点分治模板题了,时间复杂度为点分治的复杂度 \(O(n\log n)\) 乘上 \(O(m)\),总计为 \(O(nm\log n)\)


P2634 聪聪可可

题目描述:找到树上距离为 \(3\) 的倍数的边的数量

\(1 \le n \le 2\times 10^4\)

我们考虑,很少会有人第一眼会直接看出点分治吧,反正我第一眼树形DP,而且非常好写,我们设 \(f_{i,0/1/2}\) 表示以 \(i\) 为根的子树,到 \(i\) 的余数分别为 \(0/1/2\) 的边数,转移方程就很显然了

inline void dfs(int u, int fa) {
    f[u][0] = 1;
    for (int i = head[u]; i; i = e[i].nxt) {
        int v = e[i].to, w = e[i].val % 3;
        if (v == fa) continue;
        dfs(v, u);
        rep (j, 0, 2) 
            ans += f[v][j] * f[u][(6 - j - w) % 3] * 2;
        rep (j, 0, 2)
            f[u][(w + j) % 3] += f[v][j];
    }
}

我们返回来考虑,这道题有一个好的性质,就是我从 \(a\)\(b\) 如果 合法,那么从 \(b\)\(a\) 也必然合法

那么我们考虑点分治,我们仍然是开一个桶记录当前已经有多少个 \(\bmod 3\)\(0/1/2\) 的数,仿照上一道题的方法,改一改 \(\mathrm{Calc}\) 函数即可,但是有个坑点,就是两个人如果选上了同一个点的话也是合法的,而且选择是有序的,所以最后的 \(ans\) 要先乘上2再加上 \(n\),具体实现可以看代码

inline void CalcSize(int u, int f) {
    size_tree[u] = 1, maxsize[u] = 0;
    for (int i = head[u]; i; i = e[i].nxt) {
        int v = e[i].to;
        if (v == f || vis[v]) continue;
        CalcSize(v, u);
        size_tree[u] += size_tree[v];
        maxsize[u] = max(maxsize[u], size_tree[v]);
    }
    maxsize[u] = max(sumsize - size_tree[u], maxsize[u]);
    if (maxsize[u] < maxsize[rt]) rt = u;
}

inline void getdis(int u, int f) {
    tmp.push_back(dis[u] % 3);
    for (int i = head[u]; i; i = e[i].nxt) {
        int v = e[i].to, w = e[i].val;
        if (v == f || vis[v]) continue;
        dis[v] = (dis[u] + w) % 3;
        getdis(v, u);
    }
}

inline void Calc(int u, int f) {
    vis[u] = 1, ins[0] = 1;
    for (int i = head[u]; i; i = e[i].nxt) {
        int v = e[i].to, w = e[i].val;  
        if (v == f || vis[v]) continue;
        dis[v] = w % 3;
        getdis(v, u);
        int siz = tmp.size() - 1;
        rep (j, 0, siz) 
            ans += ins[(3 - tmp[j]) % 3];
        rep (j, 0, siz)
            ins[tmp[j]] ++;
        tmp.clear();
    }
    ins[0] = ins[1] = ins[2] = 0;
}

inline void solve(int u, int f) {
    Calc(u, f);
    for (int i = head[u]; i; i = e[i].nxt) {
        int v = e[i].to;
        if (v == f || vis[v]) continue;
        sumsize = size_tree[v];
        maxsize[rt = 0] = 2e9;
        CalcSize(v, u), solve(rt, u);
    }
}

int main(){
#ifndef ONLINE_JUDGE
    freopen("1.in", "r", stdin);
    freopen("1.out", "w", stdout);
#endif
    read(n);
    rep (i, 1, n - 1) {
        int u, v, w;
        read(u, v, w);
        add (u, v, w);
        add (v, u, w);
    }
    sumsize = n;
    maxsize[rt] = 2e9;
    CalcSize(1, 0);
    solve(rt, 0);
    // write(ans, '\n');
    ans = ans * 2 + n;
    int g = __gcd(ans, n * n);
    write(ans / g, '/'), write(n * n / g, '\n');
    return 0;
}