最大生成树计数
最大生成树计数
最近做到了一道题目,有关于最大生成树计数的,这里来说一下,求解最大生成树数量的方法。
题意
给你一个 \(n\) 个点 \(m\) 条边的无向图,每个边有权值,求出来这个图里,最大生成树的个数。
答案对于 \(998244353\) 取模。
题解
我们考虑一个事情,就是说,对于给定的图的任意的一个最大生成树,对于一个固定的权值 \(w\),他们对应的权值为 \(w\) 的边的个数相同。
所以,对于每个权值,分别做,计数就好了。
我们考虑已经做了一些大的权值,然后现在要做 \(w\) 权值的边的计数。
我们首先,对于之前的那些大于 \(w\) 的边,会形成一些连通块,我们对于 \(w\) 的边,只做连接两个不同连通块的边所构成的生成树的计数(这里把之前的一个连通块看成了一个点),然后用矩阵树定理得到答案即可。
需要注意的是,我们可能做到 \(w\) 的时候图是不连通的,所以,需要对于每个大于等于 \(w\) 的边的连通块分别运用矩阵树定理求解即可。
这里给出一道例题。
【2022省选十连测 Day 3】treecnt - 题目 - Zhengrui Online Judge (zhengruioi.com)
在这道题目中,我们只要对于两个点 \(x,y\) 之间,连一条 \(w(i,j)\) 的边,其中 \(w(i,j)\) 为有多少个限制里同时包含 \(i\) 点和 \(j\) 点。然后,我们一个符合要求的生成树就是权值为 \(\sum S_i - K\) 的生成树,然后同时这个生成树权值肯定最大,所以,我们对于最大生成树计数即可。
#include
const int MAXN = 505, MOD = 998244353;
using std::cin;
using std::cout;
using std::bitset;
using std::sort;
using std::vector;
using std::swap;
struct Edge {
int x, y, w, key;
} e[MAXN * MAXN];
int N, K, ecnt, f1[MAXN], f2[MAXN], tot, id[MAXN], A[MAXN][MAXN];
char s[MAXN];
bitset<2010> bs[MAXN];
vector v[MAXN * MAXN];
int find1(int x) {
return f1[x] == x ? x : f1[x] = find1(f1[x]);
}
int find2(int x) {
return f2[x] == x ? x : f2[x] = find2(f2[x]);
}
auto Mod = [] (int x) {
if (x >= MOD) {
return x - MOD;
}
else if (x < 0) {
return x + MOD;
}
else {
return x;
}
};
auto Ksm = [] (int x, int y) -> int {
int ret = 1;
for (; y; y >>= 1, x = (long long) x * x % MOD) {
if (y & 1) {
ret = (long long) ret * x % MOD;
}
}
return ret;
};
int det(int m) {
int ret = 1;
for (int i = 1; i <= m; ++i) {
for (int j = i; j <= m; ++j) {
if (A[j][i]) {
for (int k = i; k <= m; ++k) {
swap(A[i][k], A[j][k]);
}
if (j != i) {
ret = Mod(-ret);
}
break;
}
}
ret = (long long) ret * A[i][i] % MOD;
int invl = Ksm(A[i][i], MOD - 2);
for (int j = i + 1; j <= m; ++j) {
if (A[j][i]) {
int mul = (long long) invl * A[j][i] % MOD;
for (int k = i; k <= m; ++k) {
A[j][k] = Mod(A[j][k] - (long long) mul * A[i][k] % MOD);
}
}
}
}
return ret;
}
int calc(int x) {
int cnt = 0;
for (auto &i: v[x]) {
if (!id[find2(i.x)]) {
id[f2[i.x]] = ++cnt;
}
if (!id[find2(i.y)]) {
id[f2[i.y]] = ++cnt;
}
int x = id[f2[i.x]], y = id[f2[i.y]];
A[x][x] = Mod(A[x][x] + i.w);
A[y][y] = Mod(A[y][y] + i.w);
A[x][y] = Mod(A[x][y] - i.w + MOD);
A[y][x] = Mod(A[y][x] - i.w + MOD);
}
for (auto &i: v[x]) {
id[f2[i.x]] = 0;
id[f2[i.y]] = 0;
}
for (auto &i: v[x]) {
if (find2(i.x) != find2(i.y)) {
f2[f2[i.x]] = f2[i.y];
}
}
int ret = det(cnt - 1);
for (int i = 1; i <= cnt; ++i) {
for (int j = 1; j <= cnt; ++j) {
A[i][j] = 0;
}
}
return ret;
}
int main() {
std::ios::sync_with_stdio(0);
cin.tie(0);
cout.tie(0);
cin >> N >> K;
for (int i = 1, x; i <= N - 1; ++i) {
for (int j = i + 1; j <= N; ++j) {
cin >> x;
e[++ecnt] = {i, j, x, 0};
}
}
int S = 0;
for (int i = 1; i <= K; ++i) {
cin >> s + 1;
for (int j = 1; j <= N; ++j) {
bs[j].set(i, s[j] == '1');
S += s[j] == '1';
}
}
for (int i = 1; i <= ecnt; ++i) {
e[i].key = (bs[e[i].x] & bs[e[i].y]).count();
}
for (int i = 1; i <= N; ++i) {
f1[i] = i;
}
sort(e + 1, e + 1 + ecnt, [&] (const Edge &a, const Edge &b) -> int {
return a.key > b.key;
});
int singercoder = 0;
for (int i = 1; i <= ecnt; ++i) {
if (find1(e[i].x) != find1(e[i].y)) {
f1[f1[e[i].x]] = f1[e[i].y];
singercoder += e[i].key;
}
}
if (singercoder != S - K) {
cout << 0 << '\n';
return 0;
}
for (int i = 1; i <= N; ++i) {
f1[i] = f2[i] = i;
}
int ANS = 1;
for (int i = 1, p; i <= ecnt; i = p + 1) {
p = i;
while (p < ecnt && e[p + 1].key == e[i].key) {
++p;
}
for (int j = i; j <= p; ++j) {
if (find1(e[j].x) != find1(e[j].y)) {
f1[f1[e[j].x]] = f1[e[j].y];
}
}
for (int j = i; j <= p; ++j) {
if (find2(e[j].x) == find2(e[j].y)) {
continue;
}
if (!id[find1(e[j].x)]) {
id[find1(e[j].x)] = ++tot;
}
v[id[f1[e[j].x]]].push_back(e[j]);
}
for (int j = i; j <= p; ++j) {
id[f1[e[j].x]] = 0;
}
for (int j = 1; j <= tot; ++j) {
ANS = (long long) ANS * calc(j) % MOD;
v[j].clear();
}
}
cout << ANS << '\n';
return 0;
}