dls的数据结构—启发式合并,dsu on tree


启发式合并

有n个集合,si = {i}
每次将两个集合sx, sy两个集合合并,做n - 1次变成一个集合
启发式合并就是维护每个集合是什么
并查集是相当于只是打了一个标记,查询的时候再把之前的标记更新好
启发式合并是将集合真的给合并在一起,然后删掉被合并的集合
并且保证每次都是小的集合合并到大的集合里面
考虑每个元素的贡献的话,他只会被合并log次,所以总共的时间复杂度是O(nlogn)

#include

using namespace std;
const int N = 1e5+10, M = 1e6+10;
int ans;
// 这里可能换成set map
vector p[M];
int a[N];


int main(){
	int n, m; scanf("%d %d", &n, &m);
	for(int i = 1; i <= n; i ++){
		scanf("%d", &a[i]);
		p[a[i]].push_back(i);
	} 
	// 这里对边界进行了扩充,方便处理
	for(int i = 1; i <= n + 1; i ++){
		ans += a[i] != a[i - 1];
	}

	for(int i = 1; i <= m; i ++){
		int op; scanf("%d", &op);
		if(op == 1){
			int x, y; scanf("%d %d", &x, &y);
			if(x == y) continue;
                        // 这里只维护了p[x],p[y]的正确性,而没有维护a的正确性,因为a = {1, 2, 3} 和a = {3, 2, 1}
                        // 是没有任何区别的,我们只需要让a保持该不一样的还不一样就可以了
			if(p[x].size() > p[y].size()){
                                // O(1)的时间维护两个vector的不同
				p[x].swap(p[y]);
			}
			if(p[y].size() == 0) continue;
			int color = a[p[y][0]];
			auto modify = [&](int item, int color){
				ans -= (a[item] != a[item - 1]) + (a[item] != a[item + 1]);
				a[item] = color;
				ans += (a[item] != a[item - 1]) + (a[item] != a[item + 1]);
			};
			for(auto item :  p[x]){
				modify(item, color);
				p[y].push_back(item);
			}
			p[x].clear();
		}
		else if(op == 2){
			printf("%d\n", ans - 1);
		}
	}
}

// 抽离出来的关键点
// 原题目是:一个序列有不同的颜色,每次可以将x颜色的全部变成y
// p[i]表示颜色为i的节点编号是多少
// 每次合x和y的节点编号就可以了
vector p[M];
// 是一个集合的话就不用改了,直接continue;防止出线意外
if(x == y) continue;
// 这里只维护了p[x],p[y]的正确性,而没有维护序列颜色的正确性,只是保证了序列中颜色的相对不同
if(p[x].size() > p[y].size()){
 // O(1)的时间维护两个vector的不同
    p[x].swap(p[y]);
}
// 下面两句,主要是针对这道题的,就是把x都变成y的颜色,所以要先获取y的颜色
if(p[y].size() == 0) continue;
int color = a[p[y][0]];
auto modify = [&](int item, int color){
    ans -= (a[item] != a[item - 1]) + (a[item] != a[item + 1]);
    a[item] = color;
    ans += (a[item] != a[item - 1]) + (a[item] != a[item + 1]);
};
// 合并集合
for(auto item :  p[x]){
    // 合并进入一个元素带来的影响
    modify(item, color);
    p[y].push_back(item);
}
// 最好写上这个清空吧
p[x].clear();

// 边合并,边解决查询
// 对询问进行离线
// 每个集合用并查集维护,该集合具体有什么,存放到并查集的代表元素里面
// 其实是维护了两个集合,一个当前集合里面有哪些查询,另一个是集合里面有哪些点集
// 使用一个belong[i]:存放i这个元素的集合在哪里
#include

using namespace std;
const int N = 2e5+10;
array edges[N];
vector< array > que[N];
vector< int > fa[N];

int p[N];
int belong[N];
int ans[N];

int find(int x){
	if(x != p[x]) return p[x] = find(p[x]);
	return p[x];
}

int main(){
	int n, q;
	for(int i = 1; i <= n; i ++){
		p[i] = i;
		belong[i] = i;
		fa[i].push_back(i);
	} 
	for(int i = 1; i <= n - 1; i ++){
		int a, b, c; scanf("%d %d %d", &a, &b, &c);
		edges[i] = {c, a, b};
	}
	sort(edges + 1, edges + 1 + n);
	reverse(edges + 1, edges + 1 + n);
	for(int i = 1; i <= q; i ++){
		int a, b; scanf("%d %d", &a, &b);
		que[a].push_back({b, i});
		que[b].push_back({a, i});
	}

	for(int i = 1; i <= n - 1; i ++){
		int u = find(edges[i][1]), v = find(edges[i][2]);
		if(fa[u].size() > fa[v].size()){
			swap(u, v);
		}
		for(auto [node, id] : que[u]){
			if(ans[id] != 0) continue;
			if(belong[node] == v){
				ans[id] = edges[i][0];
			}
			else{
				que[v].push_back({node, id});
			}
		}
		que[u].clear();
		for(auto c : fa[u]){
			fa[v].push_back(c);
			belong[c] = v; 
		}
		fa[u].clear();
		p[u] = v;
	}
	for(int i = 1; i <= n; i ++) printf("%d\n", ans[i]);
	return 0;
}

// 这个问题是求树上路径的最小值
// 首先对边从大到小进行合并,这样的话,查询的点在左右两个集合中话答案就是当前的边的长度
array edges[N];
// 对查询进行离线
vector< array > que[N];
vector< int > fa[N];
// p就是并查集的数组
int p[N];
// belong[i]:表示i这个节点属于哪一个结合,这个主要为了O(1)判断询问的另一个端点是否在另一个集合里面
int belong[N];
int ans[N];

// 我们需要做的就是对一个包含节点的集合合并,并且对一个包含询问的节点进行合并
for(int i = 1; i <= n - 1; i ++){
    // 将每个集合挂在并查集的代表元素那里
    int u = find(edges[i][1]), v = find(edges[i][2]);
    if(fa[u].size() > fa[v].size()){
        swap(u, v);
    }
    // 能解决的询问就解决,不能解决的就合并到另一个集合
    for(auto [node, id] : que[u]){
        if(ans[id] != 0) continue;
        if(belong[node] == v){
            ans[id] = edges[i][0];
        }
        else{
            que[v].push_back({node, id});
        }
    }
    que[u].clear();
    // 合并节点集合
    for(auto c : fa[u]){
        fa[v].push_back(c);
        belong[c] = v; 
    }
    fa[u].clear();
    // 别忘了对并查集合并
    p[u] = v;
}

分治的时候分治地不是很均匀,导致分治的复杂度变成了O(n^2)
靠边的元素希望能快找到他,中间元素的话就比较无所谓
合并的代价是较小的一个,与启发式合并类似,这样的话可以使得复杂度变成O(n)
#include

using namespace std;
const int N = 2e5+10;
int A[N], nxt[N], pre[N];

bool get(int l, int r){
	if(l >= r) return true;
	for(int i = l, j = r; i <= j; i ++, j --){
		if(pre[i] < l && nxt[i] > r) return get(l, i - 1) && get(i + 1, r);
		if(pre[j] < l && nxt[j] > r) return get(l, j - 1) && get(j + 1, r);
	}
	return false;
}

void solve(){
	int n; scanf("%d", &n);
	for(int i = 1; i <= n; i ++) scanf("%d", &A[i]);
	map h;
	for(int i = 1; i <= n; i++){
		if(h.count(A[i])) pre[i] = h[A[i]];
		else pre[i] = 0;
		h[A[i]] = i;
	}
	h.clear();
	for(int i = n; i >= 1; i --){
		if(h.count(A[i])) nxt[i] = h[A[i]];
		else nxt[i] = n + 1;
		h[A[i]] = i;
	}
	if(get(1, n)) puts("non-boring");
	else puts("boring");

}

int main(){
	int T; cin >> T;
	while(T--){
		solve();
	}
}

// 启发式分治
// 分治的时候可能是不均匀的,而且不均匀的分治在每层里面还是n的
// 这样就会被卡成O(n^2)
// 所以我们希望的就是要么你每层的时间花的多点,然后分治的均匀些
// 要么你每层的时间话的少一些,分治的层数可以多些
// 这样可以O(nlogn)
// 在每层内部所话的时候是2/n小的哪部分的即可满足条件

bool solve(int l, int r){
    if(l >= r) return true;
    for(int i = l, j = r; i <= j; i ++, j --){
        if(pre[i] < l && nxt[i] > r) return solve(l, i - 1) && solve(i + 1, r);
        if(pre[j] < l && nxt[j] > r) return solve(l, j - 1) && solve(j + 1, r);
    }
    return false;
}

dsu on tree

维护子树的信息,在合并两个子树的信息的时候,将小的集合合并到大的集合里面
对于每个点,都找到一个最大的儿子(重儿子),其他叫做轻儿子
先把本身并到重儿子里面,然后依次把轻儿子合并到中之前集合里面
想法1:u的集合直接从重二子继承过来,依次把轻儿子合并过去
轻儿子合并的时候要for轻儿子子树里面所有的节点
支持单个节点的加入和删除操作就用dsu on tree

#include

using namespace std;
const int N = 1e5+10;
int c[N];
int h[N], ne[2 * N], e[2 * N], idx;
void add(int a, int b){
    e[idx] = b, ne[idx] = h[a], h[a] = idx++;
}

// hs表示重儿子是哪个儿子
int l[N], r[N], id[N], sz[N], hs[N], tot;
void dfs_init(int u, int fa){
    l[u] = ++tot;
    id[tot] = u;
    sz[u] = 1;
    hs[u] = -1;
    for(int i = h[u]; i != - 1; i = ne[i]){
        int j = e[i];
        if(j == fa) continue;
        dfs_init(j, u);
        sz[u] += sz[j];
        if(hs[u] == -1 || sz[j] > sz[hs[u]]) hs[u] = j;
    }
    r[u] = tot;
}


int cnt[N], maxcnt;
long long maxnsum, sumcnt, ans[N];
void dfs_slove(int u, int fa, bool keep){
    for(int i = h[u]; i != -1; i = ne[i]){
        int j = e[i];
        if(j ==  fa || j == hs[u]) continue;
        dfs_slove(j, u, false);
    }

    if(hs[u] != -1){
        dfs_slove(hs[u], u, true);
    }

    auto add = [&](int x){
        x = c[x];
        cnt[x] ++;
        if(cnt[x] > maxcnt ) maxcnt = cnt[x], sumcnt = x;
        else if(cnt[x] == maxcnt ) sumcnt += x;
    };
    for(int i = h[u]; i != -1; i = ne[i]){
        int j = e[i];
        if(j == fa || j == hs[u]) continue;
        for(int x = l[j]; x <= r[j]; x ++){
            add(id[x]);
        }
    }
    add(u);
    ans[u] = sumcnt;
    auto del = [&](int x){
        x = c[x];
        cnt[x] --;
    };
    
    if(!keep){
        maxcnt = 0, sumcnt = 0;
        for(int x = l[u]; x <= r[u]; x ++) del(id[x]);
    }

}

int main(){
    int n; scanf("%d", &n);
    memset(h, -1, sizeof h);
    for(int i = 1; i <= n; i ++) scanf("%d", &c[i]);
    for(int i = 1; i <= n - 1; i ++){
        int a, b; scanf("%d %d", &a, &b);
        add(a, b), add(b, a);
    }
    dfs_init(1, 0);
    dfs_slove(1, 0, false);
    for(int i = 1; i <= n; i ++) printf("%lld ", ans[i]);
    puts("");

}


// 问题求解每个子树的众数,有多个的话,将众数的数值相加
int l[N], r[N], id[N], sz[N], hs[N], tot;
// init中,我们将dfs序求出来,节点i的重儿子用hs[i]表示,没有的话就是-1
// 当然这里可以初始化些其他的东西
void dfs_init(int u, int fa){
    l[u] = ++tot;
    id[tot] = u;
    sz[u] = 1;
    hs[u] = -1;
    for(int i = h[u]; i != - 1; i = ne[i]){
        int j = e[i];
        if(j == fa) continue;
        dfs_init(j, u);
        sz[u] += sz[j];
        if(hs[u] == -1 || sz[j] > sz[hs[u]]) hs[u] = j;
    }
    r[u] = tot;
}


int cnt[N], maxcnt;
long long maxnsum, sumcnt, ans[N];
// keep表示这个节点为根的子树的信息是不是要保留,重链保留,轻链不保留
void dfs_slove(int u, int fa, bool keep){
    // 轻链的话,我们直接递归求解就可以了,最后的信息不用保留
    for(int i = h[u]; i != -1; i = ne[i]){
        int j = e[i];
        if(j ==  fa || j == hs[u]) continue;
        dfs_slove(j, u, false);
    }
    // 重链的话,我们同样递归求解,求解得到的信息需要保留
    if(hs[u] != -1){
        dfs_slove(hs[u], u, true);
    }

    auto add = [&](int x){
        x = c[x];
        cnt[x] ++;
        if(cnt[x] > maxcnt ) maxcnt = cnt[x], sumcnt = x;
        else if(cnt[x] == maxcnt ) sumcnt += x;
    };
    // 遍历轻链的每个节点,将其加入,关键是写好加入维护的信息有什么影响,对其更新就好了
    // 这里需要注意的是add(id[x]), 并不是add(x)
    for(int i = h[u]; i != -1; i = ne[i]){
        int j = e[i];
        if(j == fa || j == hs[u]) continue;
        for(int x = l[j]; x <= r[j]; x ++){
            add(id[x]);
        }
    }
    // 将根节点加入,并且这里也要记得对维护的信息进行更新
    add(u);
    ans[u] = sumcnt;
    auto del = [&](int x){
        x = c[x];
        cnt[x] --;
    };
    // 如果信息不需要保留的话,把信息清空就可以了,这里其实是比较简单的
    if(!keep){
        maxcnt = 0, sumcnt = 0;
        for(int x = l[u]; x <= r[u]; x ++) del(id[x]);
    }

}

// 查询两个集合里面各挑一个元素,是否满足某种条件
// 将其他一个进行hash,遍历另一个进行查询nlogn
#include

using namespace std;
const int N = 2e5+10;
int c[N], n, k;
int h[N], ne[2 * N], e[2 * N], w[2 * N], idx;
void add(int a, int b, int c){
    e[idx] = b, w[idx] = c, ne[idx] = h[a], h[a] = idx++;
}

// hs表示重儿子是哪个儿子
int l[N], r[N], id[N], sz[N], hs[N], tot;
long long dep1[N], dep2[N];
void dfs_init(int u, int fa){
    l[u] = ++tot;
    id[tot] = u;
    sz[u] = 1;
    hs[u] = -1;
    for(int i = h[u]; i != - 1; i = ne[i]){
        int j = e[i];
        if(j == fa) continue;
        dep1[j] = dep1[u] + 1;
        dep2[j] = dep2[u] + w[i];
        dfs_init(j, u);
        sz[u] += sz[j];
        if(hs[u] == -1 || sz[j] > sz[hs[u]]) hs[u] = j;
    }
    r[u] = tot;
}


map val;
long long ans = 1ll << 60;
void dfs_slove(int u, int fa, bool keep){
    for(int i = h[u]; i != -1; i = ne[i]){
        int j = e[i];
        if(j ==  fa || j == hs[u]) continue;
        dfs_slove(j, u, false);
    }

    if(hs[u] != -1){
        dfs_slove(hs[u], u, true);
    }

    auto query = [&](int son){
        long long d2 = k + 2 * dep2[u] - dep2[son];
        if(val.count(d2)){
            ans = min(ans, val[d2] + dep1[son] - 2 * dep1[u]);
        } 
    };

    auto add = [&](int son){
        if(val.count(dep2[son])) val[dep2[son]] = min(val[dep2[son]], dep1[son]);
        else val[dep2[son]] = dep1[son];
    };

    for(int i = h[u]; i != -1; i = ne[i]){
        int j = e[i];
        if(j == fa || j == hs[u]) continue;
        for(int x = l[j]; x <= r[j]; x ++){
            query(id[x]);
        }
        for(int x = l[j]; x <= r[j]; x ++){
            add(id[x]);
        }
    }
    query(u);
    add(u);

    if(!keep){
        val.clear();
    }

}

int main(){
    scanf("%d %d", &n, &k);
    memset(h, -1, sizeof h);
    for(int i = 1; i <= n - 1; i ++){
        int a, b, c; scanf("%d %d %d", &a, &b, &c);
        a ++, b ++;
        add(a, b, c), add(b, a, c);
    }
    dfs_init(1, 0);
    dfs_slove(1, 0, false);
    if(ans >= (1ll << 60) / 2) cout << -1 << endl;
    else cout << ans << endl;

}