【YBT2022寒假Day8 A】染色计划(Tarjan)(线段树优化建边)(树链剖分)


染色计划

题目链接:YBT2022寒假Day8 A

题目大意

给你一棵树,然后有 k 中颜色,每个点有一个颜色,然后问你要修改多少次,才能使得一个存在的颜色的所有点构成一个连通块。
一个修改操作指选择一个颜色把这个颜色的所有点的颜色改成另一个颜色。

思路

我们考虑如果选一个颜色,我们可以把这个颜色的点弄一棵树,然后路径上时一些别的颜色点构成的链。
那我们每条链上出现过的颜色都要选。

我们考虑把这个选了 \(i\) 就要选 \(j\) 的关系看做是连边,如果我们建好了图,那我们就可以跑个 Tarjan,然后找一个点使得它出发能走的距离最短。
那显然就是肯定只需要看初度为 \(0\) 的点。

但是现在有个问题就是边的数量太多,可能会达到 \(O(n^2)\) 的级别。
所以我们考虑优化建边。

怎么优化呢?
我们先考虑怎么找这些颜色,这些链要拿出来好搞(树链剖分啊,LCT啊都是可以的),但是你不好直接判断出有哪些颜色,就只能一个一个枚举过去。

然后你要知道一个东西叫做线段树优化建图。
什么东西呢,如果一个点跟一个连续的部分每个点都连了边,我们可以把它弄到线段树上,可以连表示一个区间的边。
这样点数和边数都会在 \(n\log n\) 的级别。

然后直接在这些点上面跑 Tarjan,就只需要把线段树上的父节点连接到它的两个儿子即可。
但是你要表示的是颜色的啊,你这个是点啊。

那你考虑再连边,考虑用一种方法可以把相同颜色的点合起来。
可以线段树底端连向它对应的颜色,然后每个颜色就连向那些链在线段树上分割的区间。
然后搞就好了。

代码

#include
#include
#include
#include
#define N 200020

using namespace std;

int n, k, x, y;
int c[N];
vector  e0[N], cp[N];

int dfn[N], id[N], sz[N], deg[N], fa[N], top[N], son[N], tmp;

void dfs0(int now, int father) {
	fa[now] = father; sz[now] = 1;
	deg[now] = deg[father] + 1;
	for (int i = 0; i < e0[now].size(); i++) {
		int to = e0[now][i];
		if (to == father) continue;
		dfs0(to, now); sz[now] += sz[to];
		if (sz[to] > sz[son[now]]) son[now] = to;
	}
}

void dfs1(int now, int father) {
	dfn[++tmp] = now; id[now] = tmp;
	if (son[now]) {
		top[son[now]] = top[now]; dfs1(son[now], now);
	}
	for (int i = 0; i < e0[now].size(); i++) {
		int to = e0[now][i];
		if (to == father || to == son[now]) continue;
		top[to] = to; dfs1(to, now);
	}
}

int tot;
vector  G[N << 3];

struct XD_tree {
	int idd[N << 2];
	
	void build(int now, int l, int r) {
		idd[now] = ++tot;
		if (l == r) {
			G[idd[now]].push_back(c[dfn[l]]); return ;
		}
		int mid = (l + r) >> 1;
		build(now << 1, l, mid); build(now << 1 | 1, mid + 1, r);
		G[idd[now]].push_back(idd[now << 1]); G[idd[now]].push_back(idd[now << 1 | 1]);
	}
	
	void insert(int now, int l, int r, int L, int R, int va) {
		if (L <= l && r <= R) {
			G[va].push_back(idd[now]);
			return ;
		}
		int mid = (l + r) >> 1;
		if (L <= mid) insert(now << 1, l, mid, L, R, va);
		if (mid < R) insert(now << 1 | 1, mid + 1, r, L, R, va);
	}
	
	void add(int x, int y, int va) {
		while (top[x] != top[y]) {
			if (deg[top[x]] < deg[top[y]]) swap(x, y);
			insert(1, 1, n, id[top[x]], id[x], va);
			x = fa[top[x]];
		}
		if (deg[x] < deg[y]) swap(x, y);
		insert(1, 1, n, id[y], id[x], va);
	}
}T;

int dfm[N << 3], low[N << 3], col[N << 3], f[N << 3], sta[N << 3], cnt;

void tarjan(int now) {
	dfm[now] = low[now] = ++dfm[0];
	sta[++sta[0]] = now;
	for (int i = 0; i < G[now].size(); i++) {
		int x = G[now][i];
		if (!dfm[x]) tarjan(x), low[now] = min(low[now], low[x]);
			else if (!col[x]) low[now] = min(low[now], dfm[x]);
	}
	if (low[now] == dfm[now]) {
		col[now] = ++cnt; f[cnt] = (now <= k);
		while (sta[sta[0]] != now) {
			col[sta[sta[0]]] = cnt; f[cnt] += (sta[sta[0]] <= k);
			sta[0]--;
		}
		sta[0]--;
	}
}

int out[N << 3];

int main() {
//	freopen("color.in", "r", stdin);
//	freopen("color.out", "w", stdout);
	
	scanf("%d %d", &n, &k);
	for (int i = 1; i < n; i++) {
		scanf("%d %d", &x, &y);
		e0[x].push_back(y); e0[y].push_back(x);
	}
	for (int i = 1; i <= n; i++) {
		scanf("%d", &c[i]);
		cp[c[i]].push_back(i);
	}
	
	dfs0(1, 0);
	top[1] = 1; dfs1(1, 0);
	
	tot = k; T.build(1, 1, n);
	for (int i = 1; i <= k; i++) {
		sort(cp[i].begin(), cp[i].end());
		for (int j = 0; j < cp[i].size() - 1; j++)
			T.add(cp[i][j], cp[i][j + 1], i);
	}
	
	for (int i = 1; i <= tot; i++)
		if (!dfm[i]) tarjan(i);
	for (int i = 1; i <= tot; i++)
		for (int j = 0; j < G[i].size(); j++) {
			int to = G[i][j];
			if (col[i] != col[to]) out[col[i]]++;
		}
	
	int ans = k;
	for (int i = 1; i <= cnt; i++)
		if (!out[i]) ans = min(ans, f[i]);
	printf("%d", ans - 1);
	
	return 0;
}