#树形依赖背包,点分治#BZOJ 4182 Shopping


题目

给定一棵大小为 \(n\) 的树,每个点代表一种物品,其具有体积、价值和数量的属性,

现在选择一个连通块,使得里面所有点都被选中且体积不超过 \(m\),问最大价值。

\(n\leq 500,m\leq 4000\)


分析

树形背包比较难维护,考虑用dfs序拍平到序列上,并且多重背包直接二进制拆分。

\(dp[i][j]\) 表示dfs序为 \(i\),且选择体积为 \(j\) 时能获得的最大价值。

如果不选这个点,那么 \(dp[i][j]=dp[rfn[i]][j]\)\(rfn\) 表示这个点的下一个兄弟的dfs序

如果选择这个点,那么 \(dp[i][j]=\max\{dp[i+1][j-w]+c\}\)

但有一个问题就是这样会变成01背包,考虑先用上式更新一次(强制必选一个),再用 \(dp[i][j-w]\) 更新。

就是在二进制拆分时先拆一个再正常拆,发现这样根节点强制必选,那么跑点分治,所有连通块都能被以当前根节点的情况所表示。

可以通过二进制拆分的个数来决定点的大小求带权重心,这样时间复杂度为 \(O(Tnm\log n\log m)\)


代码

#include 
#include 
using namespace std;
const int N=511; struct node{int y,next;}e[N<<1]; struct rec{int w,c;}a[N<<3];
int siz[N],big[N],as[N],L[N],R[N],w[N],c[N],ans,root,tot,v[N],dfn[N],nfd[N],rfn[N],et=1,n,m,k,dp[N][N<<3];
int iut(){
	int ans=0; char c=getchar();
	while (!isdigit(c)) c=getchar();
	while (isdigit(c)) ans=ans*10+c-48,c=getchar();
	return ans;
}
void print(int ans){
	if (ans>9) print(ans/10);
	putchar(ans%10+48);
}
void Max(int &x,int y){x=x>y?x:y;}
void dfs(int x,int fa){
	siz[x]=R[x]-L[x]+1,big[x]=0;
	for (int i=as[x];i;i=e[i].next)
	if (e[i].y!=fa&&!v[e[i].y]){
		dfs(e[i].y,x),siz[x]+=siz[e[i].y];
		Max(big[x],siz[e[i].y]);
	}
	Max(big[x],big[0]-siz[x]);
	if (big[x]<=big[root]) root=x;
}
void calc(int x,int fa){
	dfn[x]=++tot,nfd[tot]=x;
	for (int i=as[x];i;i=e[i].next)
	if (e[i].y!=fa&&!v[e[i].y])
		calc(e[i].y,x);
	rfn[x]=tot+1;
}
void Dp(int x){
	v[x]=1,tot=0,calc(x,0);
	for (int i=0;i<=k;++i) dp[tot+1][i]=0;
	for (int i=tot;i;--i){
		int x=nfd[i];
		for (int j=0;j<=k;++j) dp[i][j]=dp[rfn[x]][j];
		for (int j=k;j>=w[x];--j)
		    Max(dp[i][j],dp[i+1][j-w[x]]+c[x]);
		for (int o=R[x];o>L[x];--o)
		for (int j=k;j>=a[o].w;--j)
		    Max(dp[i][j],dp[i][j-a[o].w]+a[o].c);
	}
	Max(ans,dp[1][k]);
	for (int i=as[x];i;i=e[i].next)
	if (!v[e[i].y]){
		big[0]=siz[e[i].y];
		dfs(e[i].y,root=0),Dp(root);
	}
}
int main(){
	for (int T=iut();T;--T){
		n=iut(),k=iut(),ans=m=0,et=1;
		for (int i=1;i<=n;++i) c[i]=iut();
		for (int i=1;i<=n;++i) w[i]=iut();
		for (int i=1;i<=n;++i){
			int x=iut(); L[i]=R[i-1]+1;
			a[++m]=(rec){w[i],c[i]},--x;
			for (int t=1;x>=t;x-=t,t<<=1)
				a[++m]=(rec){w[i]*t,c[i]*t};
			if (x) a[++m]=(rec){w[i]*x,c[i]*x};
			R[i]=m;
		}
		for (int i=1;i

相关