点分治学习笔记
点分治学习笔记
模板题[洛谷P3806]
题意:给定一棵有n个点的树,询问树上距离为k的点对是否存在。
做法:对于一个点\(u\),树上所有的路径可以分为两类,一类是经过点\(u\),另一类是没有经过点\(u\),即整条路径位于\(u\)的某个子树中。那么我们就可以对于点\(u\)统计出经过他的路径是否可以构成\(k\),然后删去点\(u\),对于每个子树的挑一个点作为根\(v\),重复同样的操作。这样我们就统计出了所有的路径。
然后我们注意到,如果这棵树是一条链,最坏的情况复杂度会下降为\(O(n^2)\),为了解决这个问题,我们每次选取当前这棵树的重心作为根来分治,就可以将最坏复杂度降为\(O(nlogn)\),为了写起来方便我的代码多了一个\(log\)
Code:
#include
#define rep(i,a,b) for(int i=a;i<=b;++i)
#define pb push_back
#define Pii pair
#define x first
#define y second
const int N = 10005;
template inline void read(T &x) {
x = 0; T f = 1; char c = getchar();
while(!isdigit(c)) { if(c == '-') f = -1; c = getchar(); }
while(isdigit(c)) { x = x * 10 + c - '0'; c = getchar(); }
x *= f;
}
using namespace std;
int n, m, K[111], Ans[111];
struct edge{int e, w, nxt;} E[N << 1];
int h[N], cc;
void add(int u, int v, int w) {
E[cc].e = v; E[cc].w = w;
E[cc].nxt = h[u]; h[u] = cc; ++cc;
}
int used[N], sz[N], mxp[N];
map dep;
set S;
int idx, MN;
void dfs(int u, int pre, int num) {
sz[u] = 1; mxp[u] = 0;
for(int i = h[u]; ~i; i = E[i].nxt) if(!used[E[i].e] && E[i].e != pre){
int v = E[i].e;
dfs(v, u, num);
sz[u] += sz[v];
mxp[u] = max(mxp[u], sz[v]);
}
mxp[u] = max(mxp[u], num-sz[u]);
if(mxp[u] < MN) MN = mxp[u], idx = u;
}
int fdrt(int u,int sum) {
idx = 0, MN = __INT_MAX__;
dfs(u,0,sum);
return idx;
}
void bfs(int u, int w) {
dep[u] = w;
queue q; q.push(u);
while(!q.empty()) {
int u = q.front(); q.pop();
for(int i = h[u]; ~i ; i = E[i].nxt) if(!used[E[i].e] && dep.find(E[i].e) == dep.end()) {
int v = E[i].e;
dep[v] = dep[u] + E[i].w;
q.push(v);
}
}
}
int M[N];
void solve(int u) {
used[u] = 1; S.clear(); S.insert(0);
for(int i = h[u]; ~i; i = E[i].nxt) if(!used[E[i].e]) {
int v = E[i].e, w = E[i].w;
dep.clear();
bfs(v,w);
for(auto A: dep) {
for(int j = 1; j <= m; ++j) {
if(S.find(K[j] - A.y) != S.end()) Ans[j] |= 1;
}
}
for(auto A: dep) S.insert(A.y);
M[v] = dep.size();
}
for(int i = h[u]; ~i; i = E[i].nxt) if(!used[E[i].e]) {
solve( fdrt(E[i].e, M[E[i].e]) );
}
}
int main() {
read(n); read(m); int u, v, w; memset(h, -1, sizeof(h));
rep(i,2,n) read(u), read(v), read(w), add(u,v,w), add(v,u,w);
rep(i,1,m) read(K[i]);
int rt = fdrt(1,n);
solve(rt);
rep(i,1,m) puts(Ans[i] ? "AYE" : "NAY");
}