leetcode1569 将子数组重新排序得到同一个二叉查找树的方案数
思路:
对于给定bst,需要保证根节点出现在最前面,而左子树节点和右子树节点的相对位置可以进行排列组合。左子树和右子树内部需要递归地满足同样的条件。可以通过后续遍历逐层计算,最后将可能的方案数乘起来。学习了一种利用并查集优化时间复杂度的方法:https://leetcode-cn.com/problems/number-of-ways-to-reorder-array-to-get-same-bst/solution/jiang-zi-shu-zu-zhong-xin-pai-xu-de-dao-tong-yi-2/。
实现:
1 class DSU 2 { 3 int _n; 4 vector<int> p; 5 public: 6 DSU(int n):_n(n + 1) 7 { 8 p.resize(_n); 9 for (int i = 1; i < _n; i++) p[i] = i; 10 } 11 int find(int x) 12 { 13 if (p[x] == x) return x; 14 return p[x] = find(p[x]); 15 } 16 void uni(int x, int y) 17 { 18 x = find(x); y = find(y); 19 p[x] = y; 20 } 21 int get_root() 22 { 23 int res = 0; 24 for (int i = 1; i < _n; i++) 25 { 26 if (p[i] == i) { res = i; break; } 27 } 28 return res; 29 } 30 }; 31 class Comb 32 { 33 int _n, MOD; 34 vector<int> inv, fac, fac_inv; 35 public: 36 Comb(int n, int mod):_n(n + 1), MOD(mod) 37 { 38 inv.resize(_n); fac.resize(_n); fac_inv.resize(_n); 39 for (int i = 0; i < _n; i++) 40 { 41 if (i < 2) { fac[i] = fac_inv[i] = inv[i] = 1; continue; } 42 inv[i] = ((long long)MOD - MOD / i) * inv[MOD % i] % MOD; 43 fac[i] = (long long)fac[i - 1] * i % MOD; 44 fac_inv[i] = (long long)fac_inv[i - 1] * inv[i] % MOD; 45 } 46 } 47 int C(int i, int j) 48 { 49 return (long long)fac[i] * fac_inv[j] % MOD * fac_inv[i - j] % MOD; 50 } 51 }; 52 class Solution 53 { 54 int MOD = 1e9 + 7; 55 public: 56 int dfs(int root, vector<int>& left, vector<int>& right, vector<int>& siz, Comb& c) 57 { 58 if (root == 0) return 1; 59 int l = dfs(left[root], left, right, siz, c); 60 int r = dfs(right[root], left, right, siz, c); 61 int lc = siz[left[root]], rc = siz[right[root]]; 62 return (long long)l * r % MOD * c.C(lc + rc, lc) % MOD; 63 } 64 int numOfWays(vector<int>& nums) 65 { 66 int n = nums.size(); 67 DSU d(n); Comb c(n, MOD); 68 vector<bool> vis(n + 1, false); 69 vector<int> left(n + 1, 0), right(n + 1, 0), siz(n + 1, 1); 70 siz[0] = 0; 71 for (int i = n - 1; i >= 0; i--) 72 { 73 if (nums[i] + 1 <= n && vis[nums[i] + 1]) 74 { 75 int x = d.find(nums[i] + 1); 76 d.uni(x, nums[i]); //须将x合并到nums[i] 77 siz[nums[i]] += siz[x]; 78 right[nums[i]] = x; 79 } 80 if (nums[i] - 1 >= 1 && vis[nums[i] - 1]) 81 { 82 int x = d.find(nums[i] - 1); 83 d.uni(x, nums[i]); //须将x合并到nums[i] 84 siz[nums[i]] += siz[x]; 85 left[nums[i]] = x; 86 } 87 vis[nums[i]] = true; 88 } 89 int root = d.get_root(); 90 return max(dfs(root, left, right, siz, c) - 1, 0); 91 } 92 };