[学习笔记] 线段树合并


线段树合并,从名字上就可以看出,它是合并两颗线段树的算法,其核心就是动态开点和 merge 函数,而 merge 函数主要有两种写法,两种写法都对应这不同的清况:

首先我们假设有两棵要合并的线段树1和2,相应的结点分别为a和b

  1. 把b合并到a上
void merge(int &a,int &b,int l,int r) {
	if(!a||!b){a+=b;return;}//只要有一个点是残缺的就都看作合并到a
	if(l==r) {maxx[a]+=maxx[b];res[a]=l;return;}
	int mid=(l+r)>>1;
	merge(ls[a],ls[b],l,mid),merge(rs[a],rs[b],mid+1,r);
    maxx[a]=max(maxx[ls[a]],maxx[rs[a]]);
    res[a]=maxx[a]==maxx[ls[a]]?res[ls[a]]:res[rs[a]];
}

这种情况适用于离线,因为我们在操作的时候可能会破坏线段树2的结构

  1. 新开一个结点
int merge(int a,int b,int x,int y) {
    if(!a||!b){a+=b;return a;}
	int root=++tot;
	if(l==r) {maxx[root]=maxx[a]+maxx[b];res[root]=l;return root;}
	int mid=(l+r)>>1;
	ls[root]=merge(ls[a],ls[b],l,mid);
    rt[root]=merge(rs[a],rs[b],mid+1,r);
    maxx[root]=max(maxx[ls[root]],maxx[rs[root]]);
    res[root]=maxx[root]==maxx[ls[root]]?res[ls[root]]:res[rs[root]];
}

这种方法的优点就是支持在线,但是比较费空间

例题

P4556 [Vani有约会]雨天的尾巴 /【模板】线段树合并

我们可以在每个结点都开一棵线段树,维护这个房子里每种救济粮的个数(权值线段树),那么再运用树上差分的知识,其实我们就只需要修改四个点,最后一个点的答案就是它所在子树的线段树合并之后的答案

主要说说线段树合并,我们可以边 dfs 边进行合并操作,就是把儿子结点的线段树合并到父亲结点,合并的时候,两棵树其实是同步的我们需要判断这两棵树是否都有这个结点,如果有一个没有就直接接到a上,如果两个都有就把b上维护的值合并到a上

code
#include 
using namespace std;
int read(){
    int x=0,f=1;char ch=getchar();
    while(ch>'9'||ch<'0'){if(ch=='-')f=-1;ch=getchar();}
    while(ch>='0'&&ch<='9'){x=(x<<3)+(x<<1)+ch-'0';ch=getchar();}
    return x*f;
}
const int MAXN=1e5,N=1e5+10;
int ver[N<<1],tot,nxt[N<<1],head[N<<1],n,m;
int fa[N],dfn[N],dfstime,sz[N],top[N],son[N],dep[N],ans[N];
int rs[N*67],ls[N*67],maxx[N*67],res[N*67],rt[N];
void add(int x,int y){
    ver[++tot]=y,nxt[tot]=head[x],head[x]=tot;
    ver[++tot]=x,nxt[tot]=head[y],head[y]=tot;
}
void dfs1(int u,int father){
    fa[u]=father;sz[u]=1;dep[u]=dep[father]+1;
    for(int i=head[u];i;i=nxt[i]){
	int v=ver[i];
	if(v==fa[u]) continue;
	dfs1(v,u);
	sz[u]+=sz[v];
	if(sz[v]>sz[son[u]]) son[u]=v;
    }
}
void dfs2(int u,int anc){
    top[u]=anc;
    if(son[u]) dfs2(son[u],anc);
    for(int i=head[u];i;i=nxt[i]){
	int v=ver[i];
	if(top[v]) continue;
	dfs2(v,v);
    }
}
int LCA(int x,int y){
    while(top[x]^top[y]){
	if(dep[top[x]]dep[y]) swap(x,y);
    return x;
}
void merge(int &a,int &b,int l,int r){
    if(!a||!b){a+=b;return;}
    if(l==r){maxx[a]+=maxx[b];res[a]=l;return;}
    int mid=(l+r)>>1;
    merge(ls[a],ls[b],l,mid);
    merge(rs[a],rs[b],mid+1,r);
    maxx[a]=max(maxx[ls[a]],maxx[rs[a]]);
    res[a]=maxx[a]==maxx[ls[a]]?res[ls[a]]:res[rs[a]];
}
void upd(int &k,int l,int r,int x,int val){
    if(!k) k=++tot;
    if(l==r){maxx[k]+=val;res[k]=l;return;}
    int mid=(l+r)>>1;
    if(x<=mid) upd(ls[k],l,mid,x,val);
    else upd(rs[k],mid+1,r,x,val);
    maxx[k]=max(maxx[ls[k]],maxx[rs[k]]);
    res[k]=maxx[k]==maxx[ls[k]]?res[ls[k]]:res[rs[k]];
}
void dfs(int u,int father){
    for(int i=head[u];i;i=nxt[i]){
	int v=ver[i];
	if(v==fa[u]) continue;
	dfs(v,u);
	merge(rt[u],rt[v],1,MAXN);
    }
    if(maxx[rt[u]]) ans[u]=res[rt[u]];
}
int main(){
    n=read();m=read();
    for(int i=1;i
using namespace std;
int read(){
    int x=0,f=1;char ch=getchar();
    while(ch>'9'||ch<'0'){if(ch=='-')f=-1;ch=getchar();}
    while(ch>='0'&&ch<='9'){x=(x<<3)+(x<<1)+ch-'0';ch=getchar();}
    return x*f;
}
const int MAXN=1e5,N=1e5+10;
int ver[N<<1],tot,nxt[N<<1],head[N<<1],n,m;
int fa[N],dfn[N],dfstime,sz[N],top[N],son[N],dep[N],ans[N];
int rs[N*67],ls[N*67],maxx[N*67],res[N*67],rt[N];
void add(int x,int y){
    ver[++tot]=y,nxt[tot]=head[x],head[x]=tot;
    ver[++tot]=x,nxt[tot]=head[y],head[y]=tot;
}
void dfs1(int u,int father){
    fa[u]=father;sz[u]=1;dep[u]=dep[father]+1;
    for(int i=head[u];i;i=nxt[i]){
	int v=ver[i];
	if(v==fa[u]) continue;
	dfs1(v,u);
	sz[u]+=sz[v];
	if(sz[v]>sz[son[u]]) son[u]=v;
    }
}
void dfs2(int u,int anc){
    top[u]=anc;
    if(son[u]) dfs2(son[u],anc);
    for(int i=head[u];i;i=nxt[i]){
	int v=ver[i];
	if(top[v]) continue;
	dfs2(v,v);
    }
}
int LCA(int x,int y){
    while(top[x]^top[y]){
	if(dep[top[x]]dep[y]) swap(x,y);
    return x;
}
void merge(int &a,int &b,int l,int r){
    if(!a||!b){a+=b;return;}
    if(l==r){maxx[a]+=maxx[b];res[a]=l;return;}
    int mid=(l+r)>>1;
    merge(ls[a],ls[b],l,mid);
    merge(rs[a],rs[b],mid+1,r);
    maxx[a]=max(maxx[ls[a]],maxx[rs[a]]);
    res[a]=maxx[a]==maxx[ls[a]]?res[ls[a]]:res[rs[a]];
}
void upd(int &k,int l,int r,int x,int val){
    if(!k) k=++tot;
    if(l==r){maxx[k]+=val;res[k]=l;return;}
    int mid=(l+r)>>1;
    if(x<=mid) upd(ls[k],l,mid,x,val);
    else upd(rs[k],mid+1,r,x,val);
    maxx[k]=max(maxx[ls[k]],maxx[rs[k]]);
    res[k]=maxx[k]==maxx[ls[k]]?res[ls[k]]:res[rs[k]];
}
void dfs(int u,int father){
    for(int i=head[u];i;i=nxt[i]){
	int v=ver[i];
	if(v==fa[u]) continue;
	dfs(v,u);
	merge(rt[u],rt[v],1,MAXN);
    }
    if(maxx[rt[u]]) ans[u]=res[rt[u]];
}
int main(){
    n=read();m=read();
    for(int i=1;i

相关