最大生成树计数


最大生成树计数

最近做到了一道题目,有关于最大生成树计数的,这里来说一下,求解最大生成树数量的方法。

题意

给你一个 \(n\) 个点 \(m\) 条边的无向图,每个边有权值,求出来这个图里,最大生成树的个数。

答案对于 \(998244353\) 取模。

题解

我们考虑一个事情,就是说,对于给定的图的任意的一个最大生成树,对于一个固定的权值 \(w\),他们对应的权值为 \(w\) 的边的个数相同。

所以,对于每个权值,分别做,计数就好了。

我们考虑已经做了一些大的权值,然后现在要做 \(w\) 权值的边的计数。

我们首先,对于之前的那些大于 \(w\) 的边,会形成一些连通块,我们对于 \(w\) 的边,只做连接两个不同连通块的边所构成的生成树的计数(这里把之前的一个连通块看成了一个点),然后用矩阵树定理得到答案即可。

需要注意的是,我们可能做到 \(w\) 的时候图是不连通的,所以,需要对于每个大于等于 \(w\) 的边的连通块分别运用矩阵树定理求解即可。

这里给出一道例题。

【2022省选十连测 Day 3】treecnt - 题目 - Zhengrui Online Judge (zhengruioi.com)

在这道题目中,我们只要对于两个点 \(x,y\) 之间,连一条 \(w(i,j)\) 的边,其中 \(w(i,j)\) 为有多少个限制里同时包含 \(i\) 点和 \(j\) 点。然后,我们一个符合要求的生成树就是权值为 \(\sum S_i - K\) 的生成树,然后同时这个生成树权值肯定最大,所以,我们对于最大生成树计数即可。

#include 
const int MAXN = 505, MOD = 998244353;
using std::cin;
using std::cout;
using std::bitset;
using std::sort;
using std::vector;
using std::swap;
struct Edge {
	int x, y, w, key;
} e[MAXN * MAXN];
int N, K, ecnt, f1[MAXN], f2[MAXN], tot, id[MAXN], A[MAXN][MAXN];
char s[MAXN];
bitset<2010> bs[MAXN];
vector v[MAXN * MAXN];
int find1(int x) {
	return f1[x] == x ? x : f1[x] = find1(f1[x]);
}
int find2(int x) {
	return f2[x] == x ? x : f2[x] = find2(f2[x]);
}
auto Mod = [] (int x) {
	if (x >= MOD) {
		return x - MOD;
	}
	else if (x < 0) {
		return x + MOD;
	}
	else {
		return x;
	}
};
auto Ksm = [] (int x, int y) -> int {
	int ret = 1;
	for (; y; y >>= 1, x = (long long) x * x % MOD) {
		if (y & 1) {
			ret = (long long) ret * x % MOD;
		}
	}
	return ret;
};
int det(int m) {
	int ret = 1;
	for (int i = 1; i <= m; ++i) {
		for (int j = i; j <= m; ++j) {
			if (A[j][i]) {
				for (int k = i; k <= m; ++k) {
					swap(A[i][k], A[j][k]);
				}
				if (j != i) {
					ret = Mod(-ret);
				}
				break;
			}
		}
		ret = (long long) ret * A[i][i] % MOD;
		int invl = Ksm(A[i][i], MOD - 2);
		for (int j = i + 1; j <= m; ++j) {
			if (A[j][i]) {
				int mul = (long long) invl * A[j][i] % MOD;
				for (int k = i; k <= m; ++k) {
					A[j][k] = Mod(A[j][k] - (long long) mul * A[i][k] % MOD);
				}
			}
		}
	}
	return ret;
}
int calc(int x) {
	int cnt = 0;
	for (auto &i: v[x]) {
		if (!id[find2(i.x)]) {
			id[f2[i.x]] = ++cnt;
		}
		if (!id[find2(i.y)]) {
			id[f2[i.y]] = ++cnt;
		}
		int x = id[f2[i.x]], y = id[f2[i.y]];
		A[x][x] = Mod(A[x][x] + i.w);
		A[y][y] = Mod(A[y][y] + i.w);
		A[x][y] = Mod(A[x][y] - i.w + MOD);
		A[y][x] = Mod(A[y][x] - i.w + MOD);
	}
	for (auto &i: v[x]) {
		id[f2[i.x]] = 0;
		id[f2[i.y]] = 0;
	}
	for (auto &i: v[x]) {
		if (find2(i.x) != find2(i.y)) {
			f2[f2[i.x]] = f2[i.y];
		}
	}
	int ret = det(cnt - 1);
	for (int i = 1; i <= cnt; ++i) {
		for (int j = 1; j <= cnt; ++j) {
			A[i][j] = 0;
		}
	}
	return ret;
}
int main() {
	std::ios::sync_with_stdio(0);
	cin.tie(0);
	cout.tie(0);
	cin >> N >> K;
	for (int i = 1, x; i <= N - 1; ++i) {
		for (int j = i + 1; j <= N; ++j) {
			cin >> x;
			e[++ecnt] = {i, j, x, 0};
		}
	}
	int S = 0;
	for (int i = 1; i <= K; ++i) {
		cin >> s + 1;
		for (int j = 1; j <= N; ++j) {
			bs[j].set(i, s[j] == '1');
			S += s[j] == '1';
		}
	}
	for (int i = 1; i <= ecnt; ++i) {
		e[i].key = (bs[e[i].x] & bs[e[i].y]).count();
	}
	for (int i = 1; i <= N; ++i) {
		f1[i] = i;
	}
	sort(e + 1, e + 1 + ecnt, [&] (const Edge &a, const Edge &b) -> int {
		return a.key > b.key;
	});
	int singercoder = 0;
	for (int i = 1; i <= ecnt; ++i) {
		if (find1(e[i].x) != find1(e[i].y)) {
			f1[f1[e[i].x]] = f1[e[i].y];
			singercoder += e[i].key;
		}
	}
	if (singercoder != S - K) {
		cout << 0 << '\n';
		return 0;
	}
	for (int i = 1; i <= N; ++i) {
		f1[i] = f2[i] = i;
	}
	int ANS = 1;
	for (int i = 1, p; i <= ecnt; i = p + 1) {
		p = i;
		while (p < ecnt && e[p + 1].key == e[i].key) {
			++p;
		}
		for (int j = i; j <= p; ++j) {
			if (find1(e[j].x) != find1(e[j].y)) {
				f1[f1[e[j].x]] = f1[e[j].y];
			}
		}
		for (int j = i; j <= p; ++j) {
			if (find2(e[j].x) == find2(e[j].y)) {
				continue;
			}
			if (!id[find1(e[j].x)]) {
				id[find1(e[j].x)] = ++tot;
			}
			v[id[f1[e[j].x]]].push_back(e[j]);
		}
		for (int j = i; j <= p; ++j) {
			id[f1[e[j].x]] = 0;
		}
		for (int j = 1; j <= tot; ++j) {
			ANS = (long long) ANS * calc(j) % MOD;
			v[j].clear();
		}
	}
	cout << ANS << '\n';
	return 0;
}