AcWing 1073. 树的中心
题目传送门
一、思路分析
这个问题是 树形DP 中的一类 经典模型,常被称作 换根DP
同样,先来想一下如何暴力求解该问题:先 枚举 目标节点,然后求解该节点到其他节点的 最远距离
时间复杂度为 \(O(n^2)\),对于本题的 数据规模,十分极限,经测试只能过 7/11,代码见最下面
考虑如何优化求解该问题的方法
思考一下:在确定树的 拓扑结构 后单独求一个节点的 最远距离 时,会在该树上去比较哪些 路径 呢?
-
从当前节点往下,直到子树中某个节点的最长路径
-
从当前节点往上走到其父节点,再从其父节点出发且不回到该节点的最长路径
此处就要引入 换根DP 的思想了
换根DP 一般分为三个步骤:
- 指定任意一个根节点
- 一次dfs遍历,统计出当前子树内的节点对当前节点的贡献
- 一次dfs遍历,统计出当前节点的父节点对当前节点的贡献,然后合并统计答案
那么我们就要先 dfs 一遍,预处理出当前子树对于根的最大贡献(距离)和 次大贡献(距离)
处理 次大贡献(距离) 的原因是:
如果 当前节点 是其 父节点子树 的 最大路径 上的点,则 父节点子树 的 最大贡献 不能算作对该节点的贡献
因为我们的路径是 简单路径,不能 走回头路
然后我们再 dfs 一遍,求解出每个节点的父节点对他的贡献(即每个节点往上能到的最远路径),两者比较,取一个 max即可
d1[u]
:存下u
节点向下走的最长路径的长度
d2[u]
:存下u
节点向下走的第二长的路径的长度
p1[u]
:存下u
节点向下走的最长路径是从哪一个节点下去的
p2[u]
:存下u
节点向下走的第二长的路径是从哪一个节点走下去的
up[u]
:存下u
节点向上走的最长路径的长度
二、实现代码
#include
using namespace std;
const int N = 10010;
const int M = N << 1;
const int INF = 0x3f3f3f3f;
int n;
int h[N], e[M], w[M], ne[M], idx;
int d1[N]; //下行最长距离
int d2[N]; //下行次长距离
int up[N]; //上行最长距离
int p1[N]; //下行最长距离是走的哪一个节点获得的
//邻接表模板
void add(int a, int b, int c) {
e[idx] = b, w[idx] = c, ne[idx] = h[a], h[a] = idx++;
}
//统计出当前子树内的节点对当前节点的贡献
void dfs_d(int u, int father) {
for (int i = h[u]; i != -1; i = ne[i]) {//遍历每条出边
int j = e[i]; //连接的节点j
if (j == father) continue; //不走回头路
dfs_d(j, u); //换根递归计算以j为根的子树情况
//dfs1的结果其实已经记录到d1[j]里
if (d1[j] + w[i] >= d1[u]) { //如果可以获得更大的距离
d2[u] = d1[u]; //更新次长和最长的值
d1[u] = d1[j] + w[i];
p1[u] = j; //更新u节点的最长距离来源节点j
} else if (d1[j] + w[i] > d2[u]) //如果可以更新次长
d2[u] = d1[j] + w[i]; //那就更新
}
}
//统计出当前节点的父节点对当前节点的贡献,然后合并统计答案
void dfs_u(int u, int father) {
for (int i = h[u]; ~i; i = ne[i]) {
int j = e[i];
if (j == father) continue;
//j是u的子节点,这里在求j向上走的最长路。
//分两种情况,如果u向下的最长路经过j,则用次长路更新;否则用最长路更新。
if (p1[u] == j) up[j] = max(up[u], d2[u]) + w[i]; //用次大更新
else up[j] = max(up[u], d1[u]) + w[i]; //用最大更新
//讨论以j为根情况
dfs_u(j, u);
}
}
int main() {
//初始化邻接表
memset(h, -1, sizeof h);
cin >> n;
for (int i = 1; i < n; i++) {//n-1条边
int a, b, c;
cin >> a >> b >> c;
add(a, b, c), add(b, a, c);
}
//换根DP,两次DFS
dfs_d(1, -1);
dfs_u(1, -1);
//遍历每一个节点,找出它的最大上行距离和最大下行距离,然后取最小值
int res = INF;
for (int i = 1; i <= n; i++) res = min(res, max(d1[i], up[i]));
//输出
printf("%d\n", res);
return 0;
}
三、暴力解法
#include
using namespace std;
//直接暴力换根
//暴力办法,可以通过 7/11个数据
const int N = 10010;
const int M = N * 2;
const int INF = 0x3f3f3f3f;
int n;
int h[N], e[M], w[M], ne[M], idx;
int d[N];
//邻接表模板
void add(int a, int b, int c) {
e[idx] = b, w[idx] = c, ne[idx] = h[a], h[a] = idx++;
}
int dfs(int u, int father) {
d[u] = 0;//最长距离初始化为0
for (int i = h[u]; i != -1; i = ne[i]) {//遍历u节点的每一条出边
int j = e[i];
if (j == father) continue; //不走回头路
int dist = dfs(j, u) + w[i];//i表示u->j的边
d[u] = max(d[u], dist); //获取最长距离
}
return d[u];
}
int main() {
//邻接表初始化
memset(h, -1, sizeof h);
cin >> n;
for (int i = 1; i < n; i++) {//n-1条边
int a, b, c;
cin >> a >> b >> c;
add(a, b, c), add(b, a, c);
}
int res = INF;
//从每一个点出发,分别求一个此点到其它各点的最长距离,然后求一个min
for (int i = 1; i <= n; i++) res = min(res, dfs(i, -1));
//输出
printf("%d\n", res);
return 0;
}