[ZJOI2016]小星星
题意
link
给出 \(n(\leq 17)\) 个点 \(m\) 条边的图,和 \(n\) 个点 \(n - 1\) 条边的树,问将树每个点对应到图中不同的点后,原树上的边在图中都存在的方案数。
树形dp + 一点状压
方程
对于树上的点可以设 \(f[u][t]\) 表示树的结点 \(u\) 对应到图上的 \(t\) 点,子树已经都对应了的方案数。
如果不考虑选择的点不同,就可能会把多个树上的点对应到同一个图上的点。
最终求出来的东西是: 最多对应图上 \(n\) 个点的方案数。
而我们要求的东西是恰好对应图上 \(n\) 个点的方案数。
不妨用 \(F(S)\) 表示最多的选出 \(|S|\) 个数的方案数,\(G(S)\) 表示恰好。
那么就有:
\[F(S) = \sum_{T \subseteq S} G(T) \]根据容斥原理, 有:
\[G(S) = \sum_{T \subseteq S} (-1) ^ {|S| - |T|} F(T) \]顺带一提:这里实际上是将子集和,变回原来一个集和的和,超集也有这样的操作。
所以只要限定图中选的点, 求出 \(2 ^ n\) 个不同的点集 \(S\) 的方案数,就能算出恰好的方案数。
注意到一个小问题,就是选出的点集 \(S\) 中可能存在多对一的情况,但是在最后的全集中一定是一一对应,不存在多对一情况。
初始状态
对于叶子节点 \(u\),它可以对应任意一个点集 \(S\) 中的点, \(f[u][i] = 1, (i \in S)\)。
转移
对于节点 \(u\), 假如它选择对应点 \(i (\in S)\), 那么就可以枚举孩子 \(v\) 的对应点 \(j\), 假如在图上存在 \((i, j)\) 这条边,那么:
\[f[u][i] = \prod_{v \in son}\sum_{j = 1}^{n} f[v][j] \]最后求得 \(\sum_{i} f[1][i]\) 就是满足限定点集 \(S\), 最多用 \(|S|\) 个图中点,将树对应到图上的方案数 \(F(S)\)。
分析
树形dp \(O(n ^ 2)\)。
枚举选定的点集 \(O(2^n)\)。
总的时间复杂度就是 \(O(n^22^n)\)。
代码
开了 O2 才过的屑代码
#include
using namespace std;
using ll = long long;
const int MAXN = 20;
const int INF = 0x7fffffff;
const int mod = 998244353;
template
void Read(T &x) {
x = 0; T f = 1; char a = getchar();
for(; a < '0' || '9' < a; a = getchar()) if (a == '-') f = -f;
for(; '0' <= a && a <= '9'; a = getchar()) x = (x * 10) + (a ^ 48);
x *= f;
}
int add(int a, int b) {
int c = a + b;
if (c >= mod) c -= mod;
if (c < 0) c += mod;
return c;
}
int mul(int a, int b) {
return 1ll * a * b % mod;
}
int qpow(int a, int b) {
int sum(1);
while(b) {
if (b & 1) sum = mul(sum, a);
a = mul(a, a);
b >>= 1;
}
return sum;
}
int n, m;
vector e[MAXN];
int p[MAXN][MAXN];
ll f[MAXN][MAXN];
void dp(int u, int fa, int lim) {
for (int i = 1; i <= n; i ++)
f[u][i] = 0;
for (auto v : e[u]) {
if (v == fa) continue;
dp(v, u, lim);
}
for (int i = 1; i <= n; i ++) {
if (!(lim & (1 << i - 1))) continue;
f[u][i] = 1;
for (auto v : e[u]) {
if (v == fa) continue;
ll sum = 0;
for (int j = 1; j <= n; j ++)
if (p[i][j])
sum += f[v][j];
f[u][i] *= sum;
}
}
}
int main() {
cin >> n >> m;
for (int i = 1; i <= m; i ++) {
int u, v;
cin >> u >> v;
p[u][v] = p[v][u] = 1;
}
for (int i = 1; i < n; i ++) {
int u, v;
cin >> u >> v;
e[u].emplace_back(v);
e[v].emplace_back(u);
}
ll ans = 0;
for (int i = 1; i < (1 << n); i ++) {
ll sum = 0;
dp(1, 0, i);
for (int j = 1; j <= n; j ++)
sum += f[1][j];
int cnt = 0;
for (int j = i; j; j ^= j & -j) cnt ++;
ans += (((n - cnt) & 1) ? -1 : 1) * sum;
}
cout << ans;
return 0;
}