[学习笔记] 线段树合并
线段树合并,从名字上就可以看出,它是合并两颗线段树的算法,其核心就是动态开点和 merge
函数,而 merge
函数主要有两种写法,两种写法都对应这不同的清况:
首先我们假设有两棵要合并的线段树1和2,相应的结点分别为a和b
- 把b合并到a上
void merge(int &a,int &b,int l,int r) {
if(!a||!b){a+=b;return;}//只要有一个点是残缺的就都看作合并到a
if(l==r) {maxx[a]+=maxx[b];res[a]=l;return;}
int mid=(l+r)>>1;
merge(ls[a],ls[b],l,mid),merge(rs[a],rs[b],mid+1,r);
maxx[a]=max(maxx[ls[a]],maxx[rs[a]]);
res[a]=maxx[a]==maxx[ls[a]]?res[ls[a]]:res[rs[a]];
}
这种情况适用于离线,因为我们在操作的时候可能会破坏线段树2的结构
- 新开一个结点
int merge(int a,int b,int x,int y) {
if(!a||!b){a+=b;return a;}
int root=++tot;
if(l==r) {maxx[root]=maxx[a]+maxx[b];res[root]=l;return root;}
int mid=(l+r)>>1;
ls[root]=merge(ls[a],ls[b],l,mid);
rt[root]=merge(rs[a],rs[b],mid+1,r);
maxx[root]=max(maxx[ls[root]],maxx[rs[root]]);
res[root]=maxx[root]==maxx[ls[root]]?res[ls[root]]:res[rs[root]];
}
这种方法的优点就是支持在线,但是比较费空间
例题
P4556 [Vani有约会]雨天的尾巴 /【模板】线段树合并
我们可以在每个结点都开一棵线段树,维护这个房子里每种救济粮的个数(权值线段树),那么再运用树上差分的知识,其实我们就只需要修改四个点,最后一个点的答案就是它所在子树的线段树合并之后的答案
主要说说线段树合并,我们可以边 dfs
边进行合并操作,就是把儿子结点的线段树合并到父亲结点,合并的时候,两棵树其实是同步的我们需要判断这两棵树是否都有这个结点,如果有一个没有就直接接到a上,如果两个都有就把b上维护的值合并到a上
code
#include
using namespace std;
int read(){
int x=0,f=1;char ch=getchar();
while(ch>'9'||ch<'0'){if(ch=='-')f=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){x=(x<<3)+(x<<1)+ch-'0';ch=getchar();}
return x*f;
}
const int MAXN=1e5,N=1e5+10;
int ver[N<<1],tot,nxt[N<<1],head[N<<1],n,m;
int fa[N],dfn[N],dfstime,sz[N],top[N],son[N],dep[N],ans[N];
int rs[N*67],ls[N*67],maxx[N*67],res[N*67],rt[N];
void add(int x,int y){
ver[++tot]=y,nxt[tot]=head[x],head[x]=tot;
ver[++tot]=x,nxt[tot]=head[y],head[y]=tot;
}
void dfs1(int u,int father){
fa[u]=father;sz[u]=1;dep[u]=dep[father]+1;
for(int i=head[u];i;i=nxt[i]){
int v=ver[i];
if(v==fa[u]) continue;
dfs1(v,u);
sz[u]+=sz[v];
if(sz[v]>sz[son[u]]) son[u]=v;
}
}
void dfs2(int u,int anc){
top[u]=anc;
if(son[u]) dfs2(son[u],anc);
for(int i=head[u];i;i=nxt[i]){
int v=ver[i];
if(top[v]) continue;
dfs2(v,v);
}
}
int LCA(int x,int y){
while(top[x]^top[y]){
if(dep[top[x]]dep[y]) swap(x,y);
return x;
}
void merge(int &a,int &b,int l,int r){
if(!a||!b){a+=b;return;}
if(l==r){maxx[a]+=maxx[b];res[a]=l;return;}
int mid=(l+r)>>1;
merge(ls[a],ls[b],l,mid);
merge(rs[a],rs[b],mid+1,r);
maxx[a]=max(maxx[ls[a]],maxx[rs[a]]);
res[a]=maxx[a]==maxx[ls[a]]?res[ls[a]]:res[rs[a]];
}
void upd(int &k,int l,int r,int x,int val){
if(!k) k=++tot;
if(l==r){maxx[k]+=val;res[k]=l;return;}
int mid=(l+r)>>1;
if(x<=mid) upd(ls[k],l,mid,x,val);
else upd(rs[k],mid+1,r,x,val);
maxx[k]=max(maxx[ls[k]],maxx[rs[k]]);
res[k]=maxx[k]==maxx[ls[k]]?res[ls[k]]:res[rs[k]];
}
void dfs(int u,int father){
for(int i=head[u];i;i=nxt[i]){
int v=ver[i];
if(v==fa[u]) continue;
dfs(v,u);
merge(rt[u],rt[v],1,MAXN);
}
if(maxx[rt[u]]) ans[u]=res[rt[u]];
}
int main(){
n=read();m=read();
for(int i=1;i
using namespace std;
int read(){
int x=0,f=1;char ch=getchar();
while(ch>'9'||ch<'0'){if(ch=='-')f=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){x=(x<<3)+(x<<1)+ch-'0';ch=getchar();}
return x*f;
}
const int MAXN=1e5,N=1e5+10;
int ver[N<<1],tot,nxt[N<<1],head[N<<1],n,m;
int fa[N],dfn[N],dfstime,sz[N],top[N],son[N],dep[N],ans[N];
int rs[N*67],ls[N*67],maxx[N*67],res[N*67],rt[N];
void add(int x,int y){
ver[++tot]=y,nxt[tot]=head[x],head[x]=tot;
ver[++tot]=x,nxt[tot]=head[y],head[y]=tot;
}
void dfs1(int u,int father){
fa[u]=father;sz[u]=1;dep[u]=dep[father]+1;
for(int i=head[u];i;i=nxt[i]){
int v=ver[i];
if(v==fa[u]) continue;
dfs1(v,u);
sz[u]+=sz[v];
if(sz[v]>sz[son[u]]) son[u]=v;
}
}
void dfs2(int u,int anc){
top[u]=anc;
if(son[u]) dfs2(son[u],anc);
for(int i=head[u];i;i=nxt[i]){
int v=ver[i];
if(top[v]) continue;
dfs2(v,v);
}
}
int LCA(int x,int y){
while(top[x]^top[y]){
if(dep[top[x]]dep[y]) swap(x,y);
return x;
}
void merge(int &a,int &b,int l,int r){
if(!a||!b){a+=b;return;}
if(l==r){maxx[a]+=maxx[b];res[a]=l;return;}
int mid=(l+r)>>1;
merge(ls[a],ls[b],l,mid);
merge(rs[a],rs[b],mid+1,r);
maxx[a]=max(maxx[ls[a]],maxx[rs[a]]);
res[a]=maxx[a]==maxx[ls[a]]?res[ls[a]]:res[rs[a]];
}
void upd(int &k,int l,int r,int x,int val){
if(!k) k=++tot;
if(l==r){maxx[k]+=val;res[k]=l;return;}
int mid=(l+r)>>1;
if(x<=mid) upd(ls[k],l,mid,x,val);
else upd(rs[k],mid+1,r,x,val);
maxx[k]=max(maxx[ls[k]],maxx[rs[k]]);
res[k]=maxx[k]==maxx[ls[k]]?res[ls[k]]:res[rs[k]];
}
void dfs(int u,int father){
for(int i=head[u];i;i=nxt[i]){
int v=ver[i];
if(v==fa[u]) continue;
dfs(v,u);
merge(rt[u],rt[v],1,MAXN);
}
if(maxx[rt[u]]) ans[u]=res[rt[u]];
}
int main(){
n=read();m=read();
for(int i=1;i