CodeForces - 917D Stranger Trees


Description

\(\mathcal{P}\text{ortal.}\)

Solution

Method 1:矩阵树定理

麻了,之前还做过一道把边权写成多项式的矩阵树定理题,结果这题还是不会做 qwq.

设给定边的边权为 \(x\),其余边的边权为 \(1\),那么求出所有生成树的边权之积的和就是一个 \(n-1\) 次的多项式(考虑求行列式乘了 \(n-1\) 次),那么 \(x^k\) 项对应的系数就是有 \(k\) 条给定边的生成树个数。

如果直接用多项式进行高斯消元求行列式,复杂度则有 \(\mathcal O(n^4\log n)\)(最里面的循环有多项式乘法,所以带 \(\mathcal O(n\log n)\) 的复杂度),而且目测不会很好写。事实上可以拉格朗日插值,带入 \(x=1,2,\dots, n\) 计算 \(n\) 次行列式求得 \(n\) 个函数值,再将多项式依次降次,取 \(f(0)\) 作为系数,复杂度降到了 \(\mathcal O(n^4)\).

Method 2:容斥

\(\mathcal{C}\text{onclusion}\):一个 \(n\) 个点的无向图,若连通块个数为 \(k\),大小分别为 \(s_1,s_2,\dots,s_k\),则用 \(k-1\) 条边将 \(k\) 个连通块连起来的方案数为 \(n^{k-2}\cdot \prod_{i=1}^k s_i\).

证明就直接 \(\text{oi-wiki}\) 了。

\(g(i)\)至少\(i\) 条给定边的方案数,一个非常神奇的等价是这等价于在 给定树 上取 \(n-i\) 个连通块(也就是取 \(i\) 条给定边),其余边乱选的方案数!如果算出 \(g(i)\),那么代表恰好的 \(f(i)\) 用二项式反演也就可以算出来了。

\(g(i,j,k)\) 为以 \(i\) 为根的子树划分了 \(j\) 个连通块,当前点所在块大小为 \(k\)。由上文结论,我们只用维护 \(\prod_{i=1}^k s_i\)。那么 \(\mathcal O(n^3)\) 可以简单转移。

事实上还可以做到更优,类似 的转化,我们考虑 \(\prod_{i=1}^k s_i\) 的组合意义:在每个连通块中选一个点的方案数。设 \(g(i,j,0/1)\) 表示以 \(i\) 为根的子树划分了 \(j\) 个连通块,当前点所在块是否已选点的方案数,复杂度优化到 \(\mathcal O(n^2)\).

Code

Method 1:矩阵树定理

犯了很多神必错误,感觉自己像个脑瘫。

# include 
# include 
# define print(x,y) write(x), putchar(y)

template 
inline T read(const T sample) {
    T x=0; char s; bool f=0;
    while(!isdigit(s=getchar())) f|=(s=='-');
    for(; isdigit(s); s=getchar()) x=(x<<1)+(x<<3)+(s^48);
    return f? -x: x;
}
template 
inline void write(T x) {
    static int writ[50], w_tp=0;
    if(x<0) putchar('-'), x=-x;
    do writ[++w_tp]=x-x/10*10, x/=10; while(x);
    while(putchar(writ[w_tp--]^48), w_tp);
}

# include 
using namespace std;
typedef pair  par;

const int maxn = 105;
const int mod = 1e9+7;

inline int inv(int x,int y=mod-2,int r=1) {
	for(; y; y>>=1, x=1ll*x*x%mod)
		if(y&1) r=1ll*r*x%mod;
	return r;
}
inline void dec(int& x,int y) { x=(x-y<0?x-y+mod:x-y); }
inline void inc(int& x,int y) { x=(x+y>=mod?x+y-mod:x+y); }
inline int _dec(int x,int y) { return x-y<0?x-y+mod:x-y; }

par val[maxn];
bool arr[maxn][maxn];
int n,a[maxn][maxn];

inline int gauss(int n) {
	int ret=1, j, Inv, tmp; bool f=0;
	for(int i=1;i<=n;++i) {
		for(j=i; j<=n && !a[j][i]; ++j);
		if(i^j) swap(a[i],a[j]), f^=1;
		Inv = inv(a[i][i]);
		ret = 1ll*ret*a[i][i]%mod;
		for(j=i+1;j<=n;++j) if(a[j][i]) {
			tmp = 1ll*Inv*a[j][i]%mod;
			for(int k=i;k<=n;++k)
				dec(a[j][k],1ll*tmp*a[i][k]%mod);
		}
	}
	return f? mod-ret: ret;
}

inline void initialize(int x) {
	for(int i=1;i<=n;++i) {
		int tmp=0;
		for(int j=1;j<=n;++j) if(i^j) {
			if(arr[i][j]) a[i][j] = mod-x;
			else a[i][j] = mod-1;
			tmp += mod-a[i][j];
		}
		a[i][i] = tmp;
	}
}

int main() {
	n=read(9);
	for(int i=1;i

Method 2:容斥

# include 
# include 
# define print(x,y) write(x), putchar(y)

template 
inline T read(const T sample) {
    T x=0; char s; bool f=0;
    while(!isdigit(s=getchar())) f|=(s=='-');
    for(; isdigit(s); s=getchar()) x=(x<<1)+(x<<3)+(s^48);
    return f? -x: x;
}
template 
inline void write(T x) {
    static int writ[50], w_tp=0;
    if(x<0) putchar('-'), x=-x;
    do writ[++w_tp]=x-x/10*10, x/=10; while(x);
    while(putchar(writ[w_tp--]^48), w_tp);
}

# include 
using namespace std;

const int maxn = 105;
const int mod = 1e9+7;

inline int inv(int x,int y=mod-2,int r=1) {
	for(; y; y>>=1, x=1ll*x*x%mod)
		if(y&1) r=1ll*r*x%mod;
	return r;
}
inline void inc(int& x,int y) { x=(x+y>=mod?x+y-mod:x+y); }
inline void dec(int& x,int y) { x=(x-y<0?x-y+mod:x-y); }
inline int _inc(int x,int y) { return x+y>=mod?x+y-mod:x+y; }

vector  e[maxn];
int fac[maxn],ifac[maxn],G[maxn];
int n,g[maxn][maxn][2],tmp[maxn][2],sz[maxn];

void dfs(int u,int fa) {
	g[u][1][1]=g[u][1][0]=1; sz[u]=1;
	for(const auto& v:e[u]) if(v^fa) {
		dfs(v,u); 
		for(int i=1;i<=sz[u]+sz[v];++i) 
			tmp[i][0]=tmp[i][1]=0;
		for(int i=1;i<=sz[u];++i)
			for(int j=1;j<=sz[v];++j) {
				inc(tmp[i+j][0],1ll*g[u][i][0]*g[v][j][1]%mod);
				inc(tmp[i+j][1],1ll*g[u][i][1]*g[v][j][1]%mod);
				inc(tmp[i+j-1][0],1ll*g[u][i][0]*g[v][j][0]%mod);
				inc(tmp[i+j-1][1],(1ll*g[u][i][0]*g[v][j][1]+1ll*g[u][i][1]*g[v][j][0])%mod);
			}
		sz[u] += sz[v];
		for(int i=1;i<=sz[u];++i)
			for(int j=0;j<2;++j) g[u][i][j]=tmp[i][j];
	}
}

inline void initialize() {
	for(int i=fac[0]=1;i<=n;++i)
		fac[i] = 1ll*fac[i-1]*i%mod;
	ifac[n] = inv(fac[n]);
	for(int i=n-1;i>=0;--i)
		ifac[i] = 1ll*ifac[i+1]*(i+1)%mod;
}
inline int C(int n,int m) {
	return 1ll*fac[n]*ifac[m]%mod*ifac[n-m]%mod;
}

int main() {
	n=read(9); initialize();
	for(int i=1;i

相关