洛谷 「P4219 [BJOI2014]大融合」


传送门

\(\texttt{Description}\)

  • \(n\) 个点的图,\(m\) 次操作。

  • 初始时 \(n\) 个点相互独立。

  • 每次操作:

    • A x y:在 \(x\)\(y\) 之间连一条边(保证 \(x\)\(y\) 不联通)

    • Q x y:询问图上有多少条简单路径经过 \(x\)\(y\) 之间的边(保证 \(x\)\(y\) 之间有边)

\(\texttt{Solution}\)

  • 不难看出这张图是森林。

  • 那么每次询问的答案是什么呢?放一张图。

  • 对于 \(3\)\(6\) 之间的这条边,答案为以 \(6\) 为根的子树大小乘上整棵树的大小减去以 \(6\) 为根的子树的大小。

  • 那么对于 \(x\)\(y\) 之间的边,令 \(x\) 的深度小于 \(y\) 的深度,则答案为 \(siz_y\times(siz_{root}-siz_y)\)\(siz_i\) 为以 \(i\) 为根的子树大小)。

  • 接下来考虑如何维护子树大小。

  • 由于子树大小的信息可以从孩子上传到父亲,所以考虑线段树合并。

  • 先离线操作,预处理出每个结点的 dfs 序。

  • 对于每个节点,维护一棵权值线段树,结点 \([i,j]\) 表示当前结点的联通块内结点 dfs 序范围在 \([i,j]\) 中的数量。

  • 再套个并查集就行了,之后是线段树合并的常规套路。

\(\texttt{Code}\)

#include 
#include 

using namespace std;

inline void Swap (int &x, int &y) {x ^= y ^= x ^= y;}

int n, Q, root[100005];

struct quest {
	char op;
	int u, v;
} q[100005];

int cnt, elst[100005];
struct edge {
	int to, nxt;
} e[200005];
inline void add (int u, int v) {
	e[++cnt].to = v;
	e[cnt].nxt = elst[u];
	elst[u] = cnt;
}

int tim, dep[100005], in[100005], out[100005];
inline void dfs (int u, int fa) {
	dep[u] = dep[fa] + 1;
	in[u] = ++tim;
	for (int i = elst[u]; i; i = e[i].nxt) {
		int v = e[i].to;
		if (v == fa) continue;
		dfs(v, u);
	}
	out[u] = tim;
}

struct Segment_tree {
	int siz;
	
	struct node {
		int lson, rson, cnt;
	} t[2000005];
	
	inline int ins (int p, int l, int r, int x) {
		if (!p) p = ++siz;
		t[p].cnt = 1;
		if (l == r) return p;
		int mid = l + r >> 1;
		if (x <= mid) t[p].lson = ins(t[p].lson, l, mid, x);
		else t[p].rson = ins(t[p].rson, mid + 1, r, x);
		return p;
	}
	
	inline void pushup (int p) {t[p].cnt = t[t[p].lson].cnt + t[t[p].rson].cnt;}
	
	inline int merge (int p, int q, int l, int r) {
		if (!p) return q;
		if (!q) return p;
		if (l == r) {t[p].cnt += t[q].cnt; return p;}
		int mid = l + r >> 1;
		t[p].lson = merge(t[p].lson, t[q].lson, l, mid);
		t[p].rson = merge(t[p].rson, t[q].rson, mid + 1, r);
		pushup(p);
		return p;
	}
	
	inline int query (int p, int l, int r, int x, int y) {
		if (x <= l && y >= r) return t[p].cnt;
		int mid = l + r >> 1, res = 0;
		if (x <= mid) res += query(t[p].lson, l, mid, x, y);
		if (y > mid) res += query(t[p].rson, mid + 1, r, x, y);
		return res;
	}
} seg;

int fa[100005];
inline int find (int u) {return fa[u] == u ? u : fa[u] = find(fa[u]);}
inline void merge (int u, int v) {
	int fx = find(u), fy = find(v);
	if (fx != fy) fa[fx] = fy, root[fy] = seg.merge(root[fy], root[fx], 1, n);
}

int main () {
	scanf("%d %d", &n, &Q);
	for (int i = 1;i <= Q; i++) {
		scanf("\n%c %d %d", &q[i].op, &q[i].u, &q[i].v);
		if (q[i].op == 'A') add(q[i].u, q[i].v), add(q[i].v, q[i].u);
	}
	for (int i = 1; i <= n; i++)
		if (!in[i]) dfs(i, 0);
	for (int i = 1; i <= n; i++) fa[i] = i, root[i] = seg.ins(root[i], 1, n, in[i]);
	for (int i = 1; i <= Q; i++) {
		char op = q[i].op;
		int u = q[i].u, v = q[i].v;
		if (op == 'A') merge(u, v);
		else {
			if (dep[u] > dep[v]) Swap(u, v);
			int fx = find(u), fy = find(v);
			printf("%lld\n", 1ll * (seg.t[root[fx]].cnt - seg.query(root[fy], 1, n, in[v], out[v])) * seg.query(root[fy], 1, n, in[v], out[v]));
		}
	}
	return 0;
}