【YBT2022寒假Day3 C】毒瘤染色(LCT)(圆方树)(容斥)
毒瘤染色
题目链接:YBT2022寒假Day3 C
题目大意
要你在线实现一个操作:
一开始有 n 个点,没有边,然后操作会给你一条边。
如果保证加了之后这个图还是沙漠就加上。
然后每次加完边之后问你一开始所有点都是白色做 k 次每次随机选一个点(可能白色)的点把它变成白色,然后问你分别保留黑白点是的连通块个数的期望值。
(有部分分只用求保留白点的)
思路
考虑分别处理加边的操作和询问。
那沙漠一般来讲要么就用圆方树,要么就直接拆链搞,那这个应该是要前面的那个。
你考虑你其实就是要看两个点之间是否连通,这个好搞,和两个点的路径点是否有环,而且会出现新的环。
那我们不如用 LCT 来弄,记录方点圆点个数(方点是一个环的代表点,圆点就是普通点)
然后如果路径中有方点就是有环,就不能加边。(这里没有必要圆方点交错的那种)
然后考虑询问,考虑连通块如何表示:
在是普通的没有环的图上,它可以表示成点数 - 边数。
那现在有环,但是是沙漠,可以表示成点数 - 边数 + 环数。
那由于每次选的概率每个点相同,每个点是黑色或者白色的概率是相同的。
所以边的概率就是它连接的两个点都存在的概率,环存在的概率就是环上所有点都存在的概率。
那我们设 \(w_x\) 为选 \(x\) 个点,都是白点的概率, \(b_x\) 为都是黑点的概率。然后设 \(f_x=w_x+\omega b_x\)。
然后答案就是 \(n*f_1-m*f_2+\sum f_{sz_i}\)(\(sz_i\) 为每个环的大小)。
然后你在连边的时候更新 \(m,sz_i\) 即可。(一个是加边了就更新,一个是出现了新的环就更新)
然后接着问题就是 \(w_x,b_x\) 怎么求。
\(w_x\) 很好求,就是不被选到:\((\dfrac{n-x}{n})^k\)。
那接着就是求 \(b_x\),考虑容斥求:
\(b_x=\sum\limits_{i=0}^x(-1)^i(\dfrac{n-i}{n})^kC_x^i\)
然后就好了。
代码
#include
#include
#include
#define ll long long
#define mo 998244353
using namespace std;
ll n, q, k, w, m;
ll x, y, lst, tot;
vector tmp;
ll b1, b2, invn, f2;
ll jc[100001], inv[100001];
ll b[100001], ans;
ll ksm(ll x, ll y) {
ll re = 1;
while (y) {
if (y & 1) re = re * x % mo;
x = x * x % mo;
y >>= 1;
}
return re;
}
ll C(ll n, ll m) {
if (n < m || m < 0) return 0;
return jc[n] * inv[m] % mo * inv[n - m] % mo;
}
struct LCT {
ll ls[400001], rs[400001], cir[400001];
ll sz[400001], cirsz[400001], fa[400001];
bool lzyc[400001];
void Init() {
for (ll i = 1; i <= n; i++) sz[i] = cir[i] = cirsz[i] = 1;
}
bool nrt(ll x) {
return ls[fa[x]] == x || rs[fa[x]] == x;
}
bool lrs(ll x) {
return ls[fa[x]] == x;
}
void up(ll x) {
sz[x] = sz[ls[x]] + sz[rs[x]] + 1;
cirsz[x] = cirsz[ls[x]] + cirsz[rs[x]] + cir[x];
}
void downc(ll now) {
lzyc[now] ^= 1; swap(ls[now], rs[now]);
}
void down(ll now) {
if (lzyc[now]) {
if (ls[now]) downc(ls[now]);
if (rs[now]) downc(rs[now]);
lzyc[now] = 0;
}
}
void down_line(ll now) {
if (nrt(now)) down_line(fa[now]);
down(now);
}
void rotate(ll x) {
ll y = fa[x], z = fa[y];
ll b = lrs(x) ? rs[x] : ls[x];
if (z && nrt(y)) (lrs(y) ? ls[z] : rs[z]) = x;
if (lrs(x)) rs[x] = y, ls[y] = b;
else ls[x] = y, rs[y] = b;
fa[x] = z; fa[y] = x;
if (b) fa[b] = y;
up(y);
}
void Splay(ll x) {
down_line(x);
while (nrt(x)) {
if (nrt(fa[x])) {
if (lrs(x) == lrs(fa[x])) rotate(fa[x]);
else rotate(x);
}
rotate(x);
}
up(x);
}
void access(ll x) {
ll lst = 0;
for (; x; x = fa[x]) {
Splay(x);
rs[x] = lst; up(x);
lst = x;
}
}
void make_root(ll x) {
access(x);
Splay(x);
downc(x);
}
ll find_root(ll now) {
access(now);
Splay(now);
down(now);
while (ls[now]) {
now = ls[now]; down(now);
}
return now;
}
ll select(ll x, ll y) {
make_root(x);
access(y);
Splay(y);
return y;
}
bool link(ll x, ll y) {
if (find_root(x) == find_root(y)) return 0;
make_root(x);
fa[x] = y;
return 1;
}
void get_all(ll now) {
tmp.push_back(now);
if (ls[now]) get_all(ls[now]);
if (rs[now]) get_all(rs[now]);
}
}T;
ll get_B(ll x) {
ll re = 0, di = 1;
for (ll i = 0; i <= x; i++) {
(re += di * C(x, i) % mo * ksm((n - i) * invn % mo, k) % mo) %= mo;
di = mo - di;
}
return re;
}
ll W(ll x) {
return ksm((n - x) * invn % mo, k);
}
int main() {
// freopen("graph.in", "r", stdin);
// freopen("graph.out", "w", stdout);
scanf("%lld %lld %lld %lld", &n, &q, &k, &w); invn = ksm(n, mo - 2); tot = n;
T.Init();
jc[0] = 1; for (ll i = 1; i <= n; i++) jc[i] = jc[i - 1] * i % mo;
inv[0] = inv[1] = 1; for (ll i = 2; i <= n; i++) inv[i] = inv[mo % i] * (mo - mo / i) % mo;
for (ll i = 1; i <= n; i++) inv[i] = inv[i - 1] * inv[i] % mo;
b1 = get_B(1); b2 = get_B(2);
ans = 1ll * n * (W(1) + w * b1) % mo;
f2 = (W(2) + w * b2) % mo;
while (q--) {
scanf("%lld %lld", &x, &y); x ^= lst; y ^= lst;
if (x == y) {
printf("%lld\n", lst); continue;
}
if (!T.link(x, y)) {
ll now = T.select(x, y), sz = T.sz[now];
if (T.sz[now] == T.cirsz[now]) {
ans = (ans - f2 + mo) % mo;
T.sz[++tot] = 1;
tmp.clear(); T.get_all(now);
for (ll i = 0; i < sz; i++) {
T.fa[tmp[i]] = T.ls[tmp[i]] = T.rs[tmp[i]] = 0; T.sz[tmp[i]] = T.cir[tmp[i]] = T.cirsz[tmp[i]] = 1;
T.link(tmp[i], tot);
}
ans = (ans + W(sz)) % mo;
if (w) ans = (ans + get_B(sz)) % mo;
}
}
else ans = (ans - f2 + mo) % mo;
lst = ans;
printf("%lld\n", lst);
}
return 0;
}