12. Game on Tree 3


题目链接:Game on Tree 3

有一棵含有 \(n\) 个节点的树,节点编号从 \(1\)\(n\),根节点为 \(1\),所有非根节点均有一个正整数权值。根节点上放有一个棋子。T 和 A 两个人正在玩一个回合制游戏。一个回合中:

  • A 先选取一个非根节点,将其权值变为 \(0\)
  • 然后 T 将棋子移动到当前位置的任意一个儿子上
  • 若棋子位于叶子节点,游戏结束;T 也可以在此时强行结束游戏

游戏结束时 T 会获得棋子所在位置的权值的得分。T 想最大化得分,而 A 想最小化得分,问两人在最优策略下 T 最后的得分是多少。

先看官方题解:

由于验证 T 能否至少得到 \(x\) 分比较容易,我们可以二分他的得分。假设当前 check 的是至少得到 \(x\) 分的情况,则将树上权值小于 \(x\) 的节点染成白色,权值大于等于 \(x\) 的节点染成黑色,然后进行树形 dp:设 \(dp[u]\) 表示在以 \(u\) 为根节点的子树中 A 需要额外染色 \(dp[u]\) 次才能使 T 无法走到黑色节点。那么状态转移方程:

\[dp[u]=\max(\sum dp[v]-1,0)+[val_u\ge x] \]

其中 \(v\)\(u\) 的子节点。求和后减一是因为在 T 走下去之前还有一次变颜色的机会。

最后如果 \(dp[1]\gt 0\) 说明可以取到大于等于 \(x\) 的权值。

#include 
using namespace std;
using ll = long long;
const int maxn = 2e5 + 5;
vector g[maxn];
int a[maxn], dp[maxn];
void dfs(int u, int f, int x) {
    dp[u] = 0;
    for (auto v : g[u]) {
        if (v == f)
            continue;
        dfs(v, u, x);
        dp[u] += dp[v];
    }
    dp[u] = max(dp[u] - 1, 0) + (a[u] >= x);
}
void solve() {
    int n;
    cin >> n;
    for (int i = 2; i <= n; ++i) {
        cin >> a[i];
    }
    for (int i = 1, u, v; i < n; ++i) {
        cin >> u >> v;
        g[u].push_back(v), g[v].push_back(u);
    }
    int l = 0, r = 1e9, ans = 0;
    while (l <= r) {
        int mid = (l + r) >> 1;
        dfs(1, 0, mid);
        if (dp[1] > 0)
            l = mid + 1, ans = mid;
        else
            r = mid - 1;
    }
    cout << ans << endl;
}
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    int T = 1;
    // cin >> T;
    while (T--) {
        solve();
    }
}

这种染色的技巧在一些数据结构题中也有出现,在不太容易直接计算但比较容易 check 的情况下可以尝试一下。

那么能不能直接求出这个答案呢?我的室友给出了一种更为巧妙的做法:

先考虑树的高度是 \(1\) 的情况,显然 \(A\) 的最优选择是改变权值最大的那个叶子节点。

如果上面的这个东西是一个子树,那么它就会向父亲的地方输送除去这个权值之外的所有权值。然后又会产生一个改变权值的机会,所以就从这些权值里再删掉一个最大的。这个过程可以用可并堆来维护。

#include 
using namespace std;
using ll = long long;
const int maxn = 2e5 + 5;
const ll mod = 998244353;
vector g[maxn];
int a[maxn];
int fa[maxn], ls[maxn], rs[maxn], d[maxn];
int findfa(int x) { 
    return fa[x] == x ? fa[x] : (fa[x] = findfa(fa[x]));
}
int merge(int x, int y) {
    if (!x || !y) {
        d[x] = d[y] = 0;
        return x + y;
    }
    if (a[x] < a[y]) 
        swap(x, y);
    rs[x] = merge(rs[x], y);
    if (d[ls[x]] < d[rs[x]])
        swap(ls[x], rs[x]);
    d[x] = d[rs[x]] + 1;
    return x;
}
int join(int x, int y) {
    x = findfa(x), y = findfa(y);
    if (x == y) 
        return x;
    fa[x] = fa[y] = merge(x, y);
    return fa[x];
}
int pop(int x) {
    x = findfa(x);
    int t = x, y = rs[x];
    x = ls[x];
    fa[x] = fa[y] = fa[t] = merge(x, y);
    return fa[x];
}
int dfs(int u, int f) {
    if (u != 1 && g[u].size() == 1)
        return u;
    int rt = 0;
    for (auto v: g[u]) {
        if (v == f)
            continue;
        int r = dfs(v, u);
        if (!rt)
            rt = r;
        else 
            rt = join(rt, r);
    }
    rt = pop(rt);
    return join(rt, u);
}
void solve() {
    int n;
    cin >> n;
    for (int i = 2; i <= n; ++i) {
        cin >> a[i];
    }
    for (int i = 1; i <= n; ++i) {
        fa[i] = i;
    }
    for (int i = 1, u, v; i < n; ++i) {
        cin >> u >> v;
        g[u].push_back(v), g[v].push_back(u);
    }
    int ans = dfs(1, 0);
    cout << a[ans] << endl;
}
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    int T = 1;
    // cin >> T;
    while (T--) {
        solve();
    }
}