枇杷树


题意

\(m(\leq 300)\) 个操作,每次操作都会产生一个树的版本(从0开始), 一次操作把 \(x_i\) 版本的树的点 \(u\)\(y_i\) 版本的树的点 \(v\) 连一条权值是 \(w\) 的边(编号从0开始,\(y_i\)上的全部点的编号加 \(siz_{x_i}\)\((u,v\leq 10^{18}, siz \leq 2 \times 10 ^{18})\)

求每个版本:

\[\sum_{i = 0}^{\text{siz} - 1} \sum_{j = i + 1}^{\text{siz}} \text{dis}(i, j) \]

转化:

一般算所有路径的和可以转化为边的贡献

假设我们算出版本 \(i\) 和 版本 \(j\) 的答案,现在要求版本 \(k\) 的答案, \(i = x_k, j = y_k\)

  1. 显然版本内的路径和可以继承,这部分是 \(\text{ans}_i + \text{ans}_j\)

  2. 还有经过边 \((u, v, w)\) 的路径:

    1. \((u, v, w)\) 的贡献是 \(\text{siz}_u \times \text{siz}_v \times w\)
    2. 还有 \(i\), \(j\) 版本树内的路径贡献,不妨设 \(f_x(u)\) 为版本 \(x\) 的树中,所有点到 \(u\) 的距离和,
      这部分贡献是 \(f_i(u) \times \text{siz}_j + f_j(v) \times \text{siz}_i\)

就有转移式子:

\[\text{ans}_k = \text{ans}_i + \text{ans}_j + \text{siz}_u \times \text{siz}_v \times w + f_i(u) \times \text{siz}_j + f_j(v) \times \text{siz}_i \]

考虑计算 \(f_k(X)\), \(i = x_k, j = y_k\), 不妨设 \(\text{dis}_i(j,k)\) 表示版本 \(i\) 中,\(j\)\(k\) 的路径:

  1. \(X\)\(i\) 版本中的树:

    1. \(i\) 树内的贡献:\(f_i(X)\)
    2. \(j\) 树内的贡献,要经过边 \((u, v, w)\) 和路径 \(u \to X\): \(f_j(v) + (w + \text{dis}_i(v, X)) \times \text{siz}_j\)

    \[f_k(X) = f_i(X) + f_j(v) + (w + \text{dis}_i(u, X)) \times \text{siz}_j \]

  2. \(X\)\(j\) 版本中的树, 同理:

    \[f_k(X) = f_j(X) + f_i(u) + (w + \text{dis}_j(v, X)) \times \text{siz}_i \]

考虑计算 \(\text{dis}_k(X, Y)\):

  1. \(X\), \(Y\) 在同一个版本,直接找子版本:

    \[\text{dis}_i(X, Y) / \text{dis}_j(X, Y) \]

  2. \(X\), \(Y\) 在不同版本,假设 \(X\)\(i\), \(Y\)\(j\)

\[\text{dis}_i(X, u) + w + \text{dis}_j(v, Y) \]

就能算出距离了。

分析

对于 \(f_k(X)\), 最多有 \(k = m\) 个版本,\(X\) 的取值对多有 \(2 \times m\) 个,总的状态就是 \(2 \times m \times m\) 个, 转移时是 \(O(1)\) 的,\(\text{dis}_i(j, k)\)的计算平均是 \(O(1)\) 的,所以计算 \(f_k(X)\)\(O(m^2)\) 的。

计算 \(\text{ans}\) 时, \(f_k(x)\) 会被调用 \(O(m)\) 次.

总的时间复杂度是 \(O(m^3)\)

代码

//#pragma GCC diagnostic error "-std=c++11"
#include
using namespace std;

using ll = long long;
const int MAXN = 1200010;
const int INF = 0x7fffffff;
const int mod = 1000000007;

template 
void Read(T &x) {
	x = 0; T f = 1; char a = getchar();
	for(; a < '0' || '9' < a; a = getchar()) if (a == '-') f = -f;
	for(; '0' <= a && a <= '9'; a = getchar()) x = (x * 10) + (a ^ 48);
	x *= f;
}

ll add(ll a, ll b) {
	return ((a + b) % mod + mod) % mod;
}

ll mul(ll a, ll b) {
	return  (a % mod) * (b % mod) % mod; 
}

int m;
int x[MAXN], y[MAXN];
ll u[MAXN], v[MAXN], w[MAXN], siz[MAXN], ans[MAXN];

map > > G;
ll dis(int k, ll X, ll Y) {
	if (k == 0) return 0; 
	if (X == Y) return 0; 
	if (G[k][X][Y]) return G[k][X][Y]; 
	if (X < siz[x[k]]) {
		if (Y < siz[x[k]]) {
			return G[k][X][Y] = dis(x[k], X, Y); 
		} else {
			return G[k][X][Y] = add(add(dis(x[k], X, u[k]), dis(y[k], v[k], Y - siz[x[k]])), w[k]); 
		}
	} else {
		if (Y < siz[x[k]]) {
			return G[k][X][Y] = add(add(dis(x[k], Y, u[k]), dis(y[k], v[k], X - siz[x[k]])), w[k]); 
		} else {
			return G[k][X][Y] = dis(y[k], X - siz[x[k]], Y - siz[x[k]]);
		}
	}
}

map  > F; 
ll f(int k, ll X) 
{
	if (k == 0) return 0;
	if (F[k][X]) return F[k][X];
	if (X < siz[x[k]]) return F[k][X] = add( add( f(x[k], X), f(y[k], v[k])), mul(siz[y[k]], add(w[k], dis(x[k], u[k], X))));
	else return F[k][X] = add( add(f(x[k], u[k]), f(y[k], X - siz[x[k]])), mul(siz[x[k]], add(w[k], dis(y[k], v[k], X - siz[x[k]]))));
}

int main() { 
	cin >> m;
	siz[0] = 1;
	for (int i = 1; i <= m; i ++) {
		cin >> x[i] >> y[i] >> u[i] >> v[i] >> w[i]; 
		ans[i] = add(add(ans[x[i]], ans[y[i]]), add(mul(mul(siz[x[i]], siz[y[i]]), w[i]), add(mul(f(x[i], u[i]), siz[y[i]]), mul(f(y[i], v[i]), siz[x[i]]))));
		siz[i] = siz[x[i]] + siz[y[i]]; 
		cout << ans[i] << endl;
	}
 	return 0;
} 
/*
4
0 0 0 0 3
1 0 1 0 4
2 1 0 1 5
1 1 0 0 2
*/