dsu on tree
dsu on tree
流程
- 首先将整棵树进行重链剖分。
- 对于结点 \(u\),我们遍历其轻儿子,并计算答案,但是不保留这些结点的影响。
- 遍历其重儿子,保留其影响。
- 再次遍历 \(u\) 的轻儿子,加入这些结点的贡献,就可以得到 \(u\) 的答案。
复杂度证明
对于一棵 \(n\) 个结点的树,根结点到树上任意结点的轻边数不超过 \(\log n\) 条。证明如下:
设点 \(u\) 到根结点 \(root\) 有 \(x\) 条轻边数,且 \(u\) 的子树大小为 \(size_u\),显然轻边连接的子结点的子树大小小于等于其父亲的一半,即若 \(x\) 是 \(fa_x\) 的轻儿子,则有
\[size_x \leq \frac{1}{2}size_{fa_x} \] \[size_u \leq \frac{n}{2^x} \]显然 \(n \geq 2^x\),故有 \(x \leq \log n\)。
又因为如果一个结点是其父亲的重儿子,则他的子树必定在他的兄弟之中最多,所以任意结点到根的路径上所有重边连接的父节点在计算答案时必定不会遍历到这个结点,所以一个节点的被遍历的次数等于它到根节点路径上的轻边数 \(+1\)(他本身也要被遍历到),所以一个节点的被遍历次数为 \(\log n+1\),故总时间复杂度为 \(\mathcal O(n \log n)\)。
CF375D
#include
const int N = 1e6 + 10;
const int M = 2e6 + 10;
inline int read()
{
int cnt = 0; char ch = getchar(); bool op = 1;
for (; ! isdigit(ch); ch = getchar())
if (ch == '-') op = 0;
for (; isdigit(ch); ch = getchar())
cnt = cnt * 10 + ch - 48;
return op ? cnt : - cnt;
}
int n, m, col[N], ans[N];
int nxt[M], head[N], to[M], tot;
inline void add(int u, int v)
{
nxt[++ tot] = head[u];
head[u] = tot;
to[tot] = v;
}
int siz[N], son[N];
inline void dfs1(int u, int fa)
{
siz[u] = 1;
for (int i = head[u]; i; i = nxt[i])
{
int v = to[i];
if (v == fa) continue;
dfs1(v, u);
siz[u] += siz[v];
if (siz[son[u]] < siz[v]) son[u] = v;
}
}
struct query
{
int k, id;
};
std::vector < query > a[N];
int sumcol[N], supersum[N];
int vis[N];
inline void addson(int u, int fa, int val)
{
if (val == -1)
supersum[sumcol[col[u]]] --;
sumcol[col[u]] += val;
if (val == 1)
supersum[sumcol[col[u]]] ++;
for (int i = head[u]; i; i = nxt[i])
{
int v = to[i];
if (v == fa || vis[v]) continue;
addson(v, u, val);
}
}
inline void dfs2(int u, int fa, int op)
{
for (int i = head[u]; i; i = nxt[i])
{
int v = to[i];
if (v == fa || v == son[u]) continue;
dfs2(v, u, 1);
}
if (son[u]) dfs2(son[u], u, 0), vis[son[u]] = 1;
addson(u, fa, 1);
for (int i = 0; i < a[u].size(); ++ i)
ans[a[u][i].id] = supersum[a[u][i].k];
if (son[u]) vis[son[u]] = 0;
if (op) addson(u, fa, -1);
}
int main()
{
n = read(), m = read();
for (int i = 1; i <= n; ++ i)
col[i] = read();
for (int i = 1; i < n; ++ i)
{
int u, v;
u = read(), v = read();
add(u, v); add(v, u);
}
for (int i = 1; i <= m; ++ i)
{
int u, k;
u = read(), k = read();
query t; t.id = i, t.k = k;
a[u].push_back(t);
}
dfs1(1, 0); dfs2(1, 0, 0);
for (int i = 1; i <= m; ++ i)
printf("%d\n", ans[i]);
return 0;
}
Reference
OI-wiki