枇杷树
题意
\(m(\leq 300)\) 个操作,每次操作都会产生一个树的版本(从0开始), 一次操作把 \(x_i\) 版本的树的点 \(u\) 和 \(y_i\) 版本的树的点 \(v\) 连一条权值是 \(w\) 的边(编号从0开始,\(y_i\)上的全部点的编号加 \(siz_{x_i}\) )\((u,v\leq 10^{18}, siz \leq 2 \times 10 ^{18})\)。
求每个版本:
\[\sum_{i = 0}^{\text{siz} - 1} \sum_{j = i + 1}^{\text{siz}} \text{dis}(i, j) \]转化:
一般算所有路径的和可以转化为边的贡献。
假设我们算出版本 \(i\) 和 版本 \(j\) 的答案,现在要求版本 \(k\) 的答案, \(i = x_k, j = y_k\)。
-
显然版本内的路径和可以继承,这部分是 \(\text{ans}_i + \text{ans}_j\)。
-
还有经过边 \((u, v, w)\) 的路径:
- 边 \((u, v, w)\) 的贡献是 \(\text{siz}_u \times \text{siz}_v \times w\)。
- 还有 \(i\), \(j\) 版本树内的路径贡献,不妨设 \(f_x(u)\) 为版本 \(x\) 的树中,所有点到 \(u\) 的距离和,
这部分贡献是 \(f_i(u) \times \text{siz}_j + f_j(v) \times \text{siz}_i\)。
就有转移式子:
\[\text{ans}_k = \text{ans}_i + \text{ans}_j + \text{siz}_u \times \text{siz}_v \times w + f_i(u) \times \text{siz}_j + f_j(v) \times \text{siz}_i \]考虑计算 \(f_k(X)\), \(i = x_k, j = y_k\), 不妨设 \(\text{dis}_i(j,k)\) 表示版本 \(i\) 中,\(j\) 到 \(k\) 的路径:
-
\(X\) 在 \(i\) 版本中的树:
- \(i\) 树内的贡献:\(f_i(X)\)
- \(j\) 树内的贡献,要经过边 \((u, v, w)\) 和路径 \(u \to X\): \(f_j(v) + (w + \text{dis}_i(v, X)) \times \text{siz}_j\)
-
\(X\) 在 \(j\) 版本中的树, 同理:
\[f_k(X) = f_j(X) + f_i(u) + (w + \text{dis}_j(v, X)) \times \text{siz}_i \]
考虑计算 \(\text{dis}_k(X, Y)\):
- \(X\), \(Y\) 在同一个版本,直接找子版本:\[\text{dis}_i(X, Y) / \text{dis}_j(X, Y) \]
- \(X\), \(Y\) 在不同版本,假设 \(X\) 在 \(i\), \(Y\) 在 \(j\):
就能算出距离了。
分析
对于 \(f_k(X)\), 最多有 \(k = m\) 个版本,\(X\) 的取值对多有 \(2 \times m\) 个,总的状态就是 \(2 \times m \times m\) 个, 转移时是 \(O(1)\) 的,\(\text{dis}_i(j, k)\)的计算平均是 \(O(1)\) 的,所以计算 \(f_k(X)\) 是 \(O(m^2)\) 的。
计算 \(\text{ans}\) 时, \(f_k(x)\) 会被调用 \(O(m)\) 次.
总的时间复杂度是 \(O(m^3)\)。
代码
//#pragma GCC diagnostic error "-std=c++11"
#include
using namespace std;
using ll = long long;
const int MAXN = 1200010;
const int INF = 0x7fffffff;
const int mod = 1000000007;
template
void Read(T &x) {
x = 0; T f = 1; char a = getchar();
for(; a < '0' || '9' < a; a = getchar()) if (a == '-') f = -f;
for(; '0' <= a && a <= '9'; a = getchar()) x = (x * 10) + (a ^ 48);
x *= f;
}
ll add(ll a, ll b) {
return ((a + b) % mod + mod) % mod;
}
ll mul(ll a, ll b) {
return (a % mod) * (b % mod) % mod;
}
int m;
int x[MAXN], y[MAXN];
ll u[MAXN], v[MAXN], w[MAXN], siz[MAXN], ans[MAXN];
map > > G;
ll dis(int k, ll X, ll Y) {
if (k == 0) return 0;
if (X == Y) return 0;
if (G[k][X][Y]) return G[k][X][Y];
if (X < siz[x[k]]) {
if (Y < siz[x[k]]) {
return G[k][X][Y] = dis(x[k], X, Y);
} else {
return G[k][X][Y] = add(add(dis(x[k], X, u[k]), dis(y[k], v[k], Y - siz[x[k]])), w[k]);
}
} else {
if (Y < siz[x[k]]) {
return G[k][X][Y] = add(add(dis(x[k], Y, u[k]), dis(y[k], v[k], X - siz[x[k]])), w[k]);
} else {
return G[k][X][Y] = dis(y[k], X - siz[x[k]], Y - siz[x[k]]);
}
}
}
map > F;
ll f(int k, ll X)
{
if (k == 0) return 0;
if (F[k][X]) return F[k][X];
if (X < siz[x[k]]) return F[k][X] = add( add( f(x[k], X), f(y[k], v[k])), mul(siz[y[k]], add(w[k], dis(x[k], u[k], X))));
else return F[k][X] = add( add(f(x[k], u[k]), f(y[k], X - siz[x[k]])), mul(siz[x[k]], add(w[k], dis(y[k], v[k], X - siz[x[k]]))));
}
int main() {
cin >> m;
siz[0] = 1;
for (int i = 1; i <= m; i ++) {
cin >> x[i] >> y[i] >> u[i] >> v[i] >> w[i];
ans[i] = add(add(ans[x[i]], ans[y[i]]), add(mul(mul(siz[x[i]], siz[y[i]]), w[i]), add(mul(f(x[i], u[i]), siz[y[i]]), mul(f(y[i], v[i]), siz[x[i]]))));
siz[i] = siz[x[i]] + siz[y[i]];
cout << ans[i] << endl;
}
return 0;
}
/*
4
0 0 0 0 3
1 0 1 0 4
2 1 0 1 5
1 1 0 0 2
*/