CodeForces - 917D Stranger Trees
Description
\(\mathcal{P}\text{ortal.}\)
Solution
Method 1:矩阵树定理
麻了,之前还做过一道把边权写成多项式的矩阵树定理题,结果这题还是不会做 qwq.
设给定边的边权为 \(x\),其余边的边权为 \(1\),那么求出所有生成树的边权之积的和就是一个 \(n-1\) 次的多项式(考虑求行列式乘了 \(n-1\) 次),那么 \(x^k\) 项对应的系数就是有 \(k\) 条给定边的生成树个数。
如果直接用多项式进行高斯消元求行列式,复杂度则有 \(\mathcal O(n^4\log n)\)(最里面的循环有多项式乘法,所以带 \(\mathcal O(n\log n)\) 的复杂度),而且目测不会很好写。事实上可以拉格朗日插值,带入 \(x=1,2,\dots, n\) 计算 \(n\) 次行列式求得 \(n\) 个函数值,再将多项式依次降次,取 \(f(0)\) 作为系数,复杂度降到了 \(\mathcal O(n^4)\).
Method 2:容斥
\(\mathcal{C}\text{onclusion}\):一个 \(n\) 个点的无向图,若连通块个数为 \(k\),大小分别为 \(s_1,s_2,\dots,s_k\),则用 \(k-1\) 条边将 \(k\) 个连通块连起来的方案数为 \(n^{k-2}\cdot \prod_{i=1}^k s_i\).
证明就直接 \(\text{oi-wiki}\) 了。
设 \(g(i)\) 为 至少 有 \(i\) 条给定边的方案数,一个非常神奇的等价是这等价于在 给定树 上取 \(n-i\) 个连通块(也就是取 \(i\) 条给定边),其余边乱选的方案数!如果算出 \(g(i)\),那么代表恰好的 \(f(i)\) 用二项式反演也就可以算出来了。
令 \(g(i,j,k)\) 为以 \(i\) 为根的子树划分了 \(j\) 个连通块,当前点所在块大小为 \(k\)。由上文结论,我们只用维护 \(\prod_{i=1}^k s_i\)。那么 \(\mathcal O(n^3)\) 可以简单转移。
事实上还可以做到更优,类似 的转化,我们考虑 \(\prod_{i=1}^k s_i\) 的组合意义:在每个连通块中选一个点的方案数。设 \(g(i,j,0/1)\) 表示以 \(i\) 为根的子树划分了 \(j\) 个连通块,当前点所在块是否已选点的方案数,复杂度优化到 \(\mathcal O(n^2)\).
Code
Method 1:矩阵树定理
犯了很多神必错误,感觉自己像个脑瘫。
# include
# include
# define print(x,y) write(x), putchar(y)
template
inline T read(const T sample) {
T x=0; char s; bool f=0;
while(!isdigit(s=getchar())) f|=(s=='-');
for(; isdigit(s); s=getchar()) x=(x<<1)+(x<<3)+(s^48);
return f? -x: x;
}
template
inline void write(T x) {
static int writ[50], w_tp=0;
if(x<0) putchar('-'), x=-x;
do writ[++w_tp]=x-x/10*10, x/=10; while(x);
while(putchar(writ[w_tp--]^48), w_tp);
}
# include
using namespace std;
typedef pair par;
const int maxn = 105;
const int mod = 1e9+7;
inline int inv(int x,int y=mod-2,int r=1) {
for(; y; y>>=1, x=1ll*x*x%mod)
if(y&1) r=1ll*r*x%mod;
return r;
}
inline void dec(int& x,int y) { x=(x-y<0?x-y+mod:x-y); }
inline void inc(int& x,int y) { x=(x+y>=mod?x+y-mod:x+y); }
inline int _dec(int x,int y) { return x-y<0?x-y+mod:x-y; }
par val[maxn];
bool arr[maxn][maxn];
int n,a[maxn][maxn];
inline int gauss(int n) {
int ret=1, j, Inv, tmp; bool f=0;
for(int i=1;i<=n;++i) {
for(j=i; j<=n && !a[j][i]; ++j);
if(i^j) swap(a[i],a[j]), f^=1;
Inv = inv(a[i][i]);
ret = 1ll*ret*a[i][i]%mod;
for(j=i+1;j<=n;++j) if(a[j][i]) {
tmp = 1ll*Inv*a[j][i]%mod;
for(int k=i;k<=n;++k)
dec(a[j][k],1ll*tmp*a[i][k]%mod);
}
}
return f? mod-ret: ret;
}
inline void initialize(int x) {
for(int i=1;i<=n;++i) {
int tmp=0;
for(int j=1;j<=n;++j) if(i^j) {
if(arr[i][j]) a[i][j] = mod-x;
else a[i][j] = mod-1;
tmp += mod-a[i][j];
}
a[i][i] = tmp;
}
}
int main() {
n=read(9);
for(int i=1;i
Method 2:容斥
# include
# include
# define print(x,y) write(x), putchar(y)
template
inline T read(const T sample) {
T x=0; char s; bool f=0;
while(!isdigit(s=getchar())) f|=(s=='-');
for(; isdigit(s); s=getchar()) x=(x<<1)+(x<<3)+(s^48);
return f? -x: x;
}
template
inline void write(T x) {
static int writ[50], w_tp=0;
if(x<0) putchar('-'), x=-x;
do writ[++w_tp]=x-x/10*10, x/=10; while(x);
while(putchar(writ[w_tp--]^48), w_tp);
}
# include
using namespace std;
const int maxn = 105;
const int mod = 1e9+7;
inline int inv(int x,int y=mod-2,int r=1) {
for(; y; y>>=1, x=1ll*x*x%mod)
if(y&1) r=1ll*r*x%mod;
return r;
}
inline void inc(int& x,int y) { x=(x+y>=mod?x+y-mod:x+y); }
inline void dec(int& x,int y) { x=(x-y<0?x-y+mod:x-y); }
inline int _inc(int x,int y) { return x+y>=mod?x+y-mod:x+y; }
vector e[maxn];
int fac[maxn],ifac[maxn],G[maxn];
int n,g[maxn][maxn][2],tmp[maxn][2],sz[maxn];
void dfs(int u,int fa) {
g[u][1][1]=g[u][1][0]=1; sz[u]=1;
for(const auto& v:e[u]) if(v^fa) {
dfs(v,u);
for(int i=1;i<=sz[u]+sz[v];++i)
tmp[i][0]=tmp[i][1]=0;
for(int i=1;i<=sz[u];++i)
for(int j=1;j<=sz[v];++j) {
inc(tmp[i+j][0],1ll*g[u][i][0]*g[v][j][1]%mod);
inc(tmp[i+j][1],1ll*g[u][i][1]*g[v][j][1]%mod);
inc(tmp[i+j-1][0],1ll*g[u][i][0]*g[v][j][0]%mod);
inc(tmp[i+j-1][1],(1ll*g[u][i][0]*g[v][j][1]+1ll*g[u][i][1]*g[v][j][0])%mod);
}
sz[u] += sz[v];
for(int i=1;i<=sz[u];++i)
for(int j=0;j<2;++j) g[u][i][j]=tmp[i][j];
}
}
inline void initialize() {
for(int i=fac[0]=1;i<=n;++i)
fac[i] = 1ll*fac[i-1]*i%mod;
ifac[n] = inv(fac[n]);
for(int i=n-1;i>=0;--i)
ifac[i] = 1ll*ifac[i+1]*(i+1)%mod;
}
inline int C(int n,int m) {
return 1ll*fac[n]*ifac[m]%mod*ifac[n-m]%mod;
}
int main() {
n=read(9); initialize();
for(int i=1;i