【YBTOJ高效进阶 21189】最短距离(最短路)(最小生成树)


最短距离

题目链接:YBTOJ高效进阶 21189

题目大意

给你一个无向图,每个点可能是黑色或白色,然后边有权值,你要保留一些边,费用是它们的权值和。
然后你要使得每个白点到任意一个黑点的最短路跟原来的图一样,然后要你求最小费用,或输出无解。

思路

考虑到是任意终点,我们考虑“合并”它们:
就用一个 \(0\) 号点,所有黑点连向它,边权是 \(0\),那我们从白点到 \(0\) 号点的最短距离就是到黑点的距离啦。

那剩下的就好搞了,在最短路图上跑个最小生成树即可。
(因为连向 \(0\) 点的不用费用,然后每个白点总要连黑点,而黑点一定会连着 \(0\) 点,所以就连通了)

代码

#include
#include
#include
#include
#define ll long long

using namespace std;

struct node {
	int x, y;
	ll z;
}e[300001];
int n, m, fa[100001], tot;
int op[100001], num;
ll ans, dis[100001];

struct DIJDIJ {//最短路
	struct nde {
		ll x;
		int to, nxt;
	}e[600001];
	int le[100001], KK;
	bool in[100001];
	priority_queue , vector >, greater > > q;
	
	void add(int x, int y, ll z) {
		e[++KK] = (nde){z, y, le[x]}; le[x] = KK;
	}
	
	void work() {
		memset(dis, 0x7f, sizeof(dis));
		dis[0] = 0; q.push(make_pair(0, 0));
		while (!q.empty()) {
			int now = q.top().second; q.pop();
			if (in[now]) continue; in[now] = 1;
			for (int i = le[now]; i; i = e[i].nxt)
				if (dis[e[i].to] > dis[now] + e[i].x) {
					dis[e[i].to] = dis[now] + e[i].x;
					q.push(make_pair(dis[e[i].to], e[i].to));
				}
		}
	}
}G;

bool cmp(node x, node y) {
	return x.z < y.z;
}

int find(int now) {
	if (fa[now] == now) return now;
	return fa[now] = find(fa[now]);
}

int main() {
//	freopen("minimum.in", "r", stdin);
//	freopen("minimum.out", "w", stdout);
	
	scanf("%d %d", &n, &m);
	for (int i = 1; i <= n; i++) {
		fa[i] = i;
		scanf("%d", &op[i]);
		if (op[i] == 1) {
			e[++tot + m] = (node){0, i, 0};
		}
	}
	for (int i = 1; i <= m; i++) {
		scanf("%d %d %d", &e[i].x, &e[i].y, &e[i].z);
	}
	
	for (int i = 1; i <= m + tot; i++)
		G.add(e[i].x, e[i].y, e[i].z), G.add(e[i].y, e[i].x, e[i].z);
	G.work();
	
	sort(e + 1, e + m + tot + 1, cmp);//最小生成树
	for (int i = 1; i <= m + tot; i++) {
		if (dis[e[i].x] == dis[e[i].y] + e[i].z || dis[e[i].y] == dis[e[i].x] + e[i].z) {
			int X = find(e[i].x), Y = find(e[i].y);
			if (X == Y) continue;
			fa[X] = Y;
			num++;
			ans += e[i].z;
			if (num == n) break;
		}
	}
	
	if (num != n) printf("impossible");
		else printf("%lld", ans);
	
	return 0;
}