Codeforces 856D - Masha and Cactus(树链剖分优化 dp)


题面传送门

题意:

  • 给你一棵 \(n\) 个顶点的树和 \(m\) 条带权值的附加边
  • 你要选择一些附加边加入原树中使其成为一个仙人掌(每个点最多属于 \(1\) 个简单环)
  • 求你选择的附加边权值之和的最大值。
  • \(n \in [3,2 \times 10^5]\)\(m \in [0,2 \times 10^5]\)

很容易发现,如果你选择一条端点为 \(u,v\) 的附加边,那么 \(u\)\(v\) 路径上所有点都会被包含进一个简单环。
那么题目变为:有 \(m\) 条路径,你要选择其中某些路径使得每个点最多被一条路径覆盖。问选择的路径权值之和的最大值。
对于这一类问题,我们可以想到树形 \(dp\)。记 \(dp_u\) 为:选择一些路径,路径上所有点都在 \(u\) 的子树内,并且路径两两不相交,这些路径权值之和的最大值。
转移分两种情况:

  1. \(u\) 没有被覆盖,那么 \(dp_u=\sum\limits_{fa_v=u}dp_v\)
  2. \(u\) 被覆盖。那么显然它只能被 \(LCA=u\) 的路径覆盖(因为根据之前的定义,路径上所有节点都必须在 \(u\) 的子树内)。这种情况下就是 \(u\) 的子树挖掉这条路径后剩余部分的答案+这条路径的权值。
    是不是听起来有些抽象?不妨画个图理解看:
    假设有如下图的树,你选择了由橙色边组成的路径:

    我们把这条路径上的点全部去掉看看剩余部分是个什么东西。

    凭直觉,这部分的贡献为 \(dp[4]+dp[5]+dp[6]+dp[7]+dp[10]+dp[13]+dp[16]+dp[17]+dp[19]+dp[21]+dp[24]\)
    发现了什么了吗?
    你可能会说,项数这么多,我啥也没发现。
    那我们就把删掉的点补全后看看呗。

    通过观察,我们发现:
  • \(4,5\)\(3\) 的儿子
  • \(6,7\)\(2\) 除了 \(3\) 之外的儿子
  • \(10,13\)\(1\) 除了 \(2,15\) 之外的儿子
  • \(16,17\)\(15\) 除了 \(18\) 之外的儿子
  • \(19\)\(18\) 除了 \(20\) 之外的儿子
  • \(21,24\)\(20\) 的儿子。
    我们记 \(p_1 \to p_2 \to \dots \to p_k\) 为我们选择的路径,其中 \(p_j=i\)\(p_0=p_{k+1}=0\),那么贡献为:

\[\sum\limits_{i=1}^j\sum\limits_{fa_v=p_i}dp_v \times [v \neq p_{i-1}]+\sum\limits_{fa_v=u}dp_v \times [v \neq p_{i-1} \and v \neq p_{i+1}]+\sum\limits_{i=j+1}^k\sum\limits_{fa_v=p_i}dp_v \times [v \neq p_{i+1}]+w \]

暴力显然是 \(n^2\) 的,考虑对其进行优化。
\(g_x=\sum\limits_{fa_y=x}dp_y-dp_x\),那么上式子等价于

\[\sum\limits_{i=1}^{j-1}g_{p_i}+\sum\limits_{fa_v=u}dp_v+\sum\limits_{i=j+1}^kg_{p_i}+w \]

用树链剖分随便维护一下就可以了

/*
Contest: -
Problem: Codeforces 856D
Author: tzc_wk
Time: 2020.8.20
*/
#include 
using namespace std;
#define fi			first
#define se			second
#define pb			push_back
#define fz(i,a,b)	for(int i=a;i<=b;i++)
#define fd(i,a,b)	for(int i=a;i>=b;i--)
#define foreach(it,v) for(__typeof(v.begin()) it=v.begin();it!=v.end();it++)
#define all(a)		a.begin(),a.end()
#define fill0(a)	memset(a,0,sizeof(a))
#define fill1(a)	memset(a,-1,sizeof(a))
#define fillbig(a)	memset(a,0x3f,sizeof(a))
#define y1			y1010101010101
#define y0			y0101010101010
typedef pair pii;
typedef long long ll;
inline int read(){
	int x=0,neg=1;char c=getchar();
	while(!isdigit(c)){
		if(c=='-') neg=-1;
		c=getchar();
	}
	while(isdigit(c)) x=x*10+c-'0',c=getchar();
	return x*neg;
}
int n=read(),m=read();
int head[200005],to[400005],nxt[400005],ecnt=0;
inline void adde(int u,int v){
	to[++ecnt]=v;nxt[ecnt]=head[u];head[u]=ecnt;
}
int fa[200005],top[200005],dep[200005],siz[200005],wson[200005],dfn[200005],nd[200005],tim=0;
inline void dfs1(int x){
	siz[x]=1;
	for(int i=head[x];i;i=nxt[i]){
		int y=to[i];dep[y]=dep[x]+1;dfs1(y);siz[x]+=siz[y];
		if(siz[y]>siz[wson[x]]) wson[x]=y;
	}
}
inline void dfs2(int x,int tp){
	top[x]=tp;dfn[x]=++tim;nd[tim]=x;
	if(wson[x]) dfs2(wson[x],tp);
	for(int i=head[x];i;i=nxt[i]){
		int y=to[i];if(y!=wson[x]) dfs2(y,y);
	}
}
inline int getlca(int x,int y){
	while(top[x]!=top[y]){
		if(dep[top[x]] t[200005];
struct node{
	int l,r,val;
} s[200005<<2];
inline void build(int k,int l,int r){
	s[k].l=l;s[k].r=r;
	if(l==r) return;
	int mid=(l+r)>>1;
	build(k<<1,l,mid);
	build(k<<1|1,mid+1,r);
}
inline void modify(int k,int ind,int val){
	if(s[k].l==s[k].r){s[k].val+=val;return;}
	int mid=(s[k].l+s[k].r)>>1;
	if(ind<=mid) modify(k<<1,ind,val);
	else modify(k<<1|1,ind,val);
	s[k].val=s[k<<1].val+s[k<<1|1].val;
}
inline int query(int k,int l,int r){
	if(l<=s[k].l&&s[k].r<=r) return s[k].val;
	int mid=(s[k].l+s[k].r)>>1;
	if(r<=mid) return query(k<<1,l,r);
	else if(l>mid) return query(k<<1|1,l,r);
	else return query(k<<1,l,mid)+query(k<<1|1,mid+1,r);
}
inline int calc(int u,int v){
	int sum=0;
	while(top[u]!=top[v]){
		if(dep[top[u]]u,it->v)+it->w+sum);
	modify(1,dfn[x],sum-dp[x]);
}
int main(){
	for(int i=2;i<=n;i++){fa[i]=read();adde(fa[i],i);}
	dfs1(1);dfs2(1,1);build(1,1,n);
	for(int i=1;i<=m;i++){
		int u=read(),v=read(),w=read();
		t[getlca(u,v)].pb(path(u,v,w));
	}
	dfs3(1);printf("%d\n",dp[1]);
	return 0;
}