【luogu P4719】【模板】“动态 DP“&动态树分治(DDP)(重链剖分)(线段树)
【模板】"动态 DP"&动态树分治
题目链接:luogu P4719
题目大意
给你一棵树,点带权,每次操作修改点权,要你求最大权独立集的权值大小。
思路
这道题就是 DDP 的模板题啦。
(其实还多了一个树剖)
DDP 是什么呢,就是动态 DP,就是一些简单的 DP 在加上了一个带修,然后我们用矩阵乘法来乘模拟它转移的过程,然后用数据结构来维护矩阵的乘,从而快速处理修改。
那我们考虑先一开始的 DP。
设 \(f_{i,0/1}\) 表示 \(i\) 的子树,\(i\) 这个点不选 / 选的最大权独立集。
然后转移就是枚举儿子 \(j\):\(f_{i,0}=\sum\limits_{j}\max(f_{j,0},f_{j,1}),f_{i,1}=a_i+\sum\limits_{j}f_{j,0}\)
然后答案是 \(\max(f_{1,0},f_{1,1})\)(以 \(1\) 为根做树形 DP)
然后考虑修改之后有哪些部分要改,那就是修改的点往上都根的链。
那修改链嘛,不难想到一个东西叫做树链剖分,那我们就直接上重链剖分。
那既然要这样了,我们转移肯定不能带个 \(\sum\),不然你怎么优化,所以我们考虑根据轻重儿子弄一个 \(g_{i,0/1}\) 表示 \(i\) 的子树,\(i\) 的轻儿子可选可不选 / 都不能选的最大权独立集。
那我们设 \(j\) 是 \(i\) 的重儿子:\(f_{i,0}=g_{i,0}+\max(f_{j,0},f_{j,1}),f_{i,1}=a_i+g_{i,1}+f_{j,1}\)
然后发现 \(f_{i,1}\) 里面 \(a_i,g_{i,1}\) 两个下标都是 \(i\) 相关的,考虑把 \(a_i\) 放进 \(g_{i,1}\) 里面。
那 \(g_{i,1}\) 就表示 \(i\) 的子树,\(i\) 的轻儿子都不选,\(i\) 选的最大权独立集。
然后重新写一次:
\(f_{i,0}=g_{i,0}+\max(f_{j,0},f_{j,1})\)
\(f_{i,1}=g_{i,1}+f_{j,1}\)
那轻边的转移就好了,接着考虑重链要怎么线段树维护。
看到这种简洁的东西我们考虑怎么用矩阵乘法来优化。
发现问题出在 \(\max\) 上,普通的矩阵乘法搞不了这玩意儿。
那我们就大胆重新定义矩阵乘法咯。
定义一个 \(*\),\(A*B=\max\limits_{k}\{A_{i,k},B_{k,j}\}\)
记得验证结合律:
你 \(\max(\max(x,y),z)=\max(x,y,z)\),\(\max(x,\max(y,z))=\max(x,y,z)\),那肯定一样啦。
(其实就是因为 \(\max\) 是满足结合律的)
然后你想想能不能用这个来搞。
首先第一个的样子肯定得改:\(f_{i,0}=g_{i,0}+\max(f_{j,0},f_{j,1})=\max(g_{i,0}+f_{j,0},g_{i,0}+f_{j,1})\)
那第二个我们就可以看做是 \(f_{i,1}=\max(g_{i,1}+f_{j,1},-\infty)\)
那矩阵是不是可以是:
\(\begin{vmatrix}g_{i,0}&g_{i,1}\\ g_{i,0}&\infty\end{vmatrix}\)
好像可以?
那我们就有:\(\begin{vmatrix}f_{j,0}&f_{j,1}\end{vmatrix}*\begin{vmatrix}g_{i,0}&g_{i,1}\\ g_{i,0}&\infty\end{vmatrix}=\begin{vmatrix}f_{i,0}&f_{i,1}\end{vmatrix}\)
那就可以啦!
然后你会发现直接上就错了,因为你这个树形 DP 是自下而上的,那线段树来说你是从左到右,也就是从上到下的。
那接下来就是两个解决方法:
- 直接改乘起来的顺序,线段树里面改,就合并的时候是 \(val_{2x+1}*val_{2x}\),然后查询答案的时候也是右边的答案乘上左边的。
- 直接改矩阵的样子,让它满足:\(Z*\begin{vmatrix}f_{j,0}&f_{j,1}\end{vmatrix}=\begin{vmatrix}f_{i,0}&f_{i,1}\end{vmatrix}\)(\(Z\) 就是要改成的矩阵)
然后再这里其实不难,就原来的矩阵重心对称一下就好了:\(\begin{vmatrix}g_{i,0}&g_{i,0}\\g_{i,1}&\infty\end{vmatrix}\)
然后搞就可以啦。
代码
改乘起来顺序形
#include
#include
#include
using namespace std;
const int N = 1e5 + 100;
const int M = 2;
int n, m, a[N], x, y;
vector G[N];
struct matrix {
int a[M][M];
matrix() {
memset(a, -0x3f, sizeof(a));
}
matrix operator *(matrix y) {
matrix re;
for (int k = 0; k < 2; k++)
for (int i = 0; i < 2; i++)
for (int j = 0; j < 2; j++)
re.a[i][j] = max(re.a[i][j], a[i][k] + y.a[k][j]);
return re;
}
};
int fa[N], sz[N], son[N], top[N], dfn[N], id[N], End[N], f[N][2];
matrix val[N];
void dfs0(int now, int father) {
sz[now] = 1; fa[now] = father;
for (int i = 0; i < G[now].size(); i++) {
int x = G[now][i]; if (x == father) continue;
dfs0(x, now); sz[now] += sz[x];
if (sz[x] > sz[son[now]]) son[now] = x;
}
}
void dfs1(int now, int father) {
f[now][0] = 0; f[now][1] = a[now];
val[now].a[0][0] = val[now].a[1][0] = 0;
val[now].a[0][1] = a[now];
dfn[++dfn[0]] = now; id[now] = dfn[0];
if (son[now]) {
top[son[now]] = top[now]; dfs1(son[now], now);
f[now][0] += max(f[son[now]][0], f[son[now]][1]);
f[now][1] += f[son[now]][0];
}
else End[top[now]] = now;
for (int i = 0; i < G[now].size(); i++) {
int x = G[now][i]; if (x == father || x == son[now]) continue;
top[x] = x; dfs1(x, now);
f[now][0] += max(f[x][0], f[x][1]);
f[now][1] += f[x][0];
val[now].a[0][0] += max(f[x][0], f[x][1]); val[now].a[1][0] += max(f[x][0], f[x][1]);
val[now].a[0][1] += f[x][0];
}
}
struct XD_tree {
matrix v[N << 2];
void up(int now) {
v[now] = v[now << 1 | 1] * v[now << 1];
}
void build(int now, int l, int r) {
if (l == r) {
v[now] = val[dfn[l]]; return ;
}
int mid = (l + r) >> 1; build(now << 1, l, mid); build(now << 1 | 1, mid + 1, r);
up(now);
}
void change(int now, int l, int r, int pl) {
if (l == r) {
v[now] = val[dfn[l]]; return ;
}
int mid = (l + r) >> 1;
if (pl <= mid) change(now << 1, l, mid, pl);
else change(now << 1 | 1, mid + 1, r, pl);
up(now);
}
matrix query(int now, int l, int r, int L, int R) {
if (L <= l && r <= R) return v[now];
int mid = (l + r) >> 1;
if (L <= mid && mid < R) return query(now << 1 | 1, mid + 1, r, L, R) * query(now << 1, l, mid, L, R);
if (L <= mid) return query(now << 1, l, mid, L, R);
if (mid < R) return query(now << 1 | 1, mid + 1, r, L, R);
}
void update(int now, int va) {
val[now].a[0][1] -= a[now]; val[now].a[0][1] += va; a[now] = va;//直接记每个数组当前的值,这样线段树就不用下传矩阵了
matrix bef, aft;
while (now) {
bef = query(1, 1, n, id[top[now]], id[End[top[now]]]);
change(1, 1, n, id[now]);
aft = query(1, 1, n, id[top[now]], id[End[top[now]]]);
now = fa[top[now]];
if (!now) break;
val[now].a[0][0] -= max(bef.a[0][0], bef.a[0][1]); val[now].a[0][0] += max(aft.a[0][0], aft.a[0][1]);
val[now].a[1][0] -= max(bef.a[0][0], bef.a[0][1]); val[now].a[1][0] += max(aft.a[0][0], aft.a[0][1]);
val[now].a[0][1] -= bef.a[0][0]; val[now].a[0][1] += aft.a[0][0];
}
}
}T;
int main() {
scanf("%d %d", &n, &m);
for (int i = 1; i <= n; i++) scanf("%d", &a[i]);
for (int i = 1; i < n; i++) {
scanf("%d %d", &x, &y); G[x].push_back(y); G[y].push_back(x);
}
dfs0(1, 0); top[1] = 1; dfs1(1, 0);
T.build(1, 1, n);
for (int i = 1; i <= m; i++) {
scanf("%d %d", &x, &y);
T.update(x, y);
matrix ans = T.query(1, 1, n, id[1], id[End[1]]);
//记得是从下向上搞,所以你要记录一个end表示这条重链的下面,然后一个点的答案是从它链下面到它
printf("%d\n", max(ans.a[0][0], ans.a[0][1]));
}
return 0;
}
改矩阵形
#include
#include
#include
using namespace std;
const int N = 1e5 + 100;
const int M = 2;
int n, m, a[N], x, y;
vector G[N];
struct matrix {
int a[M][M];
matrix() {
memset(a, -0x3f, sizeof(a));
}
matrix operator *(matrix y) {
matrix re;
for (int k = 0; k < 2; k++)
for (int i = 0; i < 2; i++)
for (int j = 0; j < 2; j++)
re.a[i][j] = max(re.a[i][j], a[i][k] + y.a[k][j]);
return re;
}
};
int fa[N], sz[N], son[N], top[N], dfn[N], id[N], End[N], f[N][2];
matrix val[N];
void dfs0(int now, int father) {
sz[now] = 1; fa[now] = father;
for (int i = 0; i < G[now].size(); i++) {
int x = G[now][i]; if (x == father) continue;
dfs0(x, now); sz[now] += sz[x];
if (sz[x] > sz[son[now]]) son[now] = x;
}
}
void dfs1(int now, int father) {
f[now][0] = 0; f[now][1] = a[now];
val[now].a[0][0] = val[now].a[0][1] = 0;
val[now].a[1][0] = a[now];
dfn[++dfn[0]] = now; id[now] = dfn[0];
if (son[now]) {
top[son[now]] = top[now]; dfs1(son[now], now);
f[now][0] += max(f[son[now]][0], f[son[now]][1]);
f[now][1] += f[son[now]][0];
}
else End[top[now]] = now;
for (int i = 0; i < G[now].size(); i++) {
int x = G[now][i]; if (x == father || x == son[now]) continue;
top[x] = x; dfs1(x, now);
f[now][0] += max(f[x][0], f[x][1]);
f[now][1] += f[x][0];
val[now].a[0][0] += max(f[x][0], f[x][1]); val[now].a[0][1] += max(f[x][0], f[x][1]);
val[now].a[1][0] += f[x][0];
}
}
struct XD_tree {
matrix v[N << 2];
void up(int now) {
v[now] = v[now << 1] * v[now << 1 | 1];
}
void build(int now, int l, int r) {
if (l == r) {
v[now] = val[dfn[l]]; return ;
}
int mid = (l + r) >> 1; build(now << 1, l, mid); build(now << 1 | 1, mid + 1, r);
up(now);
}
void change(int now, int l, int r, int pl) {
if (l == r) {
v[now] = val[dfn[l]]; return ;
}
int mid = (l + r) >> 1;
if (pl <= mid) change(now << 1, l, mid, pl);
else change(now << 1 | 1, mid + 1, r, pl);
up(now);
}
matrix query(int now, int l, int r, int L, int R) {
if (L <= l && r <= R) return v[now];
int mid = (l + r) >> 1;
if (L <= mid && mid < R) return query(now << 1, l, mid, L, R) * query(now << 1 | 1, mid + 1, r, L, R);
if (L <= mid) return query(now << 1, l, mid, L, R);
if (mid < R) return query(now << 1 | 1, mid + 1, r, L, R);
}
void update(int now, int va) {
val[now].a[1][0] -= a[now]; val[now].a[1][0] += va; a[now] = va;//直接记每个数组当前的值,这样线段树就不用下传矩阵了
matrix bef, aft;
while (now) {
bef = query(1, 1, n, id[top[now]], id[End[top[now]]]);
change(1, 1, n, id[now]);
aft = query(1, 1, n, id[top[now]], id[End[top[now]]]);
now = fa[top[now]];
if (!now) break;
val[now].a[0][0] -= max(bef.a[0][0], bef.a[1][0]); val[now].a[0][0] += max(aft.a[0][0], aft.a[1][0]);
val[now].a[0][1] -= max(bef.a[0][0], bef.a[1][0]); val[now].a[0][1] += max(aft.a[0][0], aft.a[1][0]);
val[now].a[1][0] -= bef.a[0][0]; val[now].a[1][0] += aft.a[0][0];
}
}
}T;
int main() {
scanf("%d %d", &n, &m);
for (int i = 1; i <= n; i++) scanf("%d", &a[i]);
for (int i = 1; i < n; i++) {
scanf("%d %d", &x, &y); G[x].push_back(y); G[y].push_back(x);
}
dfs0(1, 0); top[1] = 1; dfs1(1, 0);
T.build(1, 1, n);
for (int i = 1; i <= m; i++) {
scanf("%d %d", &x, &y);
T.update(x, y);
matrix ans = T.query(1, 1, n, id[1], id[End[1]]);
//记得是从下向上搞,所以你要记录一个end表示这条重链的下面,然后一个点的答案是从它链下面到它
printf("%d\n", max(ans.a[0][0], ans.a[1][0]));
}
return 0;
}