【P1040 [NOIP2003 提高组] 加分二叉树】题解


题目链接

题目

设一个 \(n\) 个节点的二叉树 \(\text{tree}\) 的中序遍历为\((1,2,3,\ldots,n)\),其中数字 \(1,2,3,\ldots,n\) 为节点编号。每个节点都有一个分数(均为正整数),记第 \(i\) 个节点的分数为 \(d_i\)\(\text{tree}\) 及它的每个子树都有一个加分,任一棵子树 \(\text{subtree}\)(也包含 \(\text{tree}\) 本身)的加分计算方法如下:

\(\text{subtree}\) 的左子树的加分 \(\times\) \(\text{subtree}\) 的右子树的加分 \(+\) \(\text{subtree}\) 的根的分数。

若某个子树为空,规定其加分为 \(1\),叶子的加分就是叶节点本身的分数。不考虑它的空子树。

试求一棵符合中序遍历为 \((1,2,3,\ldots,n)\) 且加分最高的二叉树 \(\text{tree}\)。要求输出

  1. \(\text{tree}\) 的最高加分。

  2. \(\text{tree}\) 的前序遍历。

思路

区间dp。

\(dp(l,r)\) 表示在中序遍历为 \(l\cdots r\) 时的最大分数,那么我们可以枚举根节点 \(k\)

\[\Large dp(l,r)=\max_{k=l}^r(a_k+dp(l, k-1)\times dp(k+1, r)) \]

转移过程中存储一下由哪个转移过来的,就可以输出先序遍历了。

总结

这道题标签上说是树形dp,但我用了区间dp。

这道题让我记起了三个遍历:

  1. 先序遍历:根左右
  2. 中序遍历:左根右
  3. 后序遍历:左右根

下次不要再搞混啦!

Code

// Problem: P1040 [NOIP2003 提高组] 加分二叉树
// Contest: Luogu
// URL: https://www.luogu.com.cn/problem/P1040
// Memory Limit: 128 MB
// Time Limit: 1000 ms
// 
// Powered by CP Editor (https://cpeditor.org)

#include
using namespace std;
#define int long long
inline int read(){int x=0,f=1;char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')f=-1;
ch=getchar();}while(ch>='0'&&ch<='9'){x=(x<<1)+
(x<<3)+(ch^48);ch=getchar();}return x*f;}
#define N 35
//#define M
//#define mo
int n, m, i, j, k; 
int dp[N][N], a[N], g[N][N]; 
int len, l, r; 

void dfs(int l, int r)
{
	if(l>r) return ; 
	printf("%lld ", g[l][r]); 
	dfs(l, g[l][r]-1); 
	dfs(g[l][r]+1, r); 
}

signed main()
{
//	freopen("tiaoshi.in","r",stdin);
//	freopen("tiaoshi.out","w",stdout);
	n=read(); 
	for(i=1; i<=n; ++i)
	{
		a[i]=read(); 
		dp[i][i]=a[i]; 
		dp[i-1][i]=a[i-1]+a[i]; 
		g[i-1][i]=i-1; g[i][i]=i; 
	}
	for(len=3; len<=n; ++len)
		for(l=1, r=len; r<=n; ++l, ++r)
		{
			dp[l][r]=max(a[l]+dp[l+1][r], a[r]+dp[l][r-1]); 
			g[l][r]=(a[l]+dp[l+1][r]>a[r]+dp[l][r-1] ? l : r); 
			for(k=l+1; k<=r-1; ++k)
				if(a[k]+dp[l][k-1]*dp[k+1][r]>dp[l][r])
					dp[l][r]=a[k]+dp[l][k-1]*dp[k+1][r], 
					g[l][r]=k; 
			// printf("dp[%lld][%lld]=%lld\n", l, r, dp[l][r]); 
		}
	printf("%lld\n", dp[1][n]); 
	dfs(1, n); 
	return 0;
}