Topcoder 10055 CactusAutomorphisms


题目链接

Topcoder 10055 CactusAutomorphisms

题目大意

给你一个 \(n\) 个点的仙人掌,求它的自同构数。

\(1\leq n\leq 200\)

思路

自同构可以直观理解为仙人掌外面有一圈轮廓,我们把里面的仙人掌拿下来,换个姿势再塞到轮廓里面。

自同构类型的问题在树型结构上会比较好解决,于是考虑对仙人掌建出圆方树,即树上原来的点是圆点,对于每个环建一个方点,环上的点与新点连边,然后把原来环上的边删去,不过本来其它的边中间也是要建方点的,在这里没必要建。

一棵树可以选择很多根,但是至多只有 \(2\) 个重心,于是考虑把重心做根,这样两棵同构的树的根就不会发生变化,适合子树递归解决问题,当树有 \(2\) 个重心时(它们一定是相邻的),在两点间新建一个圆点,这样重心就变成中间的圆点了。

接下来子树递归解决问题,当前点是圆点时,其儿子可以任意排序,于是我们把同构的子树放一组,子树的答案之积乘以每一组大小的阶乘 即为当前子树的答案。当前点是方点时,由于方点的儿子曾经是在环上的,是有序的,所以只有正反方向放两种方案,要判断两者是否同构,另外如果目前是在根处,没有父亲,则还可以旋转环,\(O(n)\) 旋转一圈检查是否同构即可。

对于判断两个有根树是否同构,可以使用树哈希,考虑树 \(dfs\) 的过程,一个点入栈记为 \(1\),出栈记为 \(0\),则每棵不同构的树都对应了唯一的一个 \(01\) 序列,作为二进制取模存起来即可。注意到这里做的是圆方树,圆点和方点有别,所以方点入栈另记为 \(2\),圆方点出栈都为 \(0\),三进制储存。

在计算当前点所辖子树的哈希值时,圆点的子树无序,可以将它们的哈希值从小到大排序后拼接起来,然后开头加入一个 \(1\),尾部加入一个 \(0\),方点的子树可以正反放,可以分别计算两种方向的哈希值,取较小的,开头加 \(2\) 尾部加 \(0\) 即可。

时间复杂度:\(O(n\log n)\)

实现细节

  • 当前点是方点时,儿子们正反放是相当于环的翻转的,所以必须连续,如果父亲在邻接表的中间,需要将父亲前面的节点接到后面节点的尾部,这样顺序才是正确的。
  • 细节比较多,但好像没什么别的标志性的了。吐槽一句这个题一点都不像 Topcoder 的风格。

Code

边拍边调写出来的,有点丑。

#include
#include
#include
#include
#include
#include
#define mem(a,b) memset(a, b, sizeof(a))
#define rep(i,a,b) for(int i = (a); i <= (b); i++)
#define per(i,b,a) for(int i = (b); i >= (a); i--)
#define N 820
#define ll long long
#define mod 1000000003
using namespace std;

class CactusAutomorphisms{
    public:
    int n;
    int head[N], to[4*N], nxt[4*N];
    int low[N], dfn[N], c[N], siz[N];
    int cnt, scc, num;
    bool square[N];
    stack s;
    vector center;

    int dc_u, dc_v; // double centers
    ll hash[N], ans[N], pre[N], suf[N];
    ll fact[N], pow[N];

    void init(){ mem(head, -1), cnt = -1; }
    void add_e(int a, int b, bool id){
        nxt[++cnt] = head[a], head[a] = cnt, to[cnt] = b;
        if(id) add_e(b, a, 0);
    }

    void tarjan(int x, int fa){
        dfn[x] = low[x] = ++num;
        s.push(x);
        for(int i = head[x]; ~i; i = nxt[i]){
            int y = to[i];
            if(!dfn[y]){
                tarjan(y, x);
                low[x] = min(low[x], low[y]);
            } else if(y != fa) low[x] = min(low[x], dfn[y]);
        }
        if(low[x] == dfn[x]){
            int y; scc++;
            bool circ = s.top() != x;
            if(circ) n++, square[n] = true;
            do{
                y = s.top(); s.pop();
                c[y] = scc;
                if(circ) add_e(y, n, 1);
            } while(y != x);
        }
    }

    void dfs(int x, int fa){
        siz[x] = 1;
        int mx = 0;
        for(int i = head[x]; ~i; i = nxt[i]){
            int y = to[i];
            if(y == fa || (c[y] == c[x])) continue;
            dfs(y, x);
            siz[x] += siz[y], mx = max(mx, siz[y]);
        }
        mx = max(mx, n-siz[x]);
        if(mx <= n/2) center.push_back(x);
    } 

    int count(vector subt, ll standard){
        int cnt = 0, m = subt.size();
        pre[0] = hash[subt[0]];
        rep(i,1,m-1) pre[i] = (pre[i-1] * pow[2*siz[subt[i]]] + hash[subt[i]]) % mod;
        int tot = 0;
        per(i,m-1,0) suf[i] = (hash[subt[i]] * pow[tot] + suf[i+1]) % mod, tot += 2*siz[subt[i]];
        tot = 2*n-2;
        cnt += pre[m-1] == standard;
        per(i,m-1,1){
            tot -= 2*siz[subt[i]];
            ll val = (suf[i] * pow[tot] + pre[i-1]) % mod;
            if(val == standard) cnt++;
        }
        return cnt;
    }

    void solve_for_square_root(int x, vector subt){
        int cnt = 1, m = subt.size();
        pre[0] = hash[subt[0]];
        rep(i,1,m-1) pre[i] = (pre[i-1] * pow[2*siz[subt[i]]] + hash[subt[i]]) % mod;
        ll standard = pre[m-1];

        int t = count(subt, standard);
        reverse(subt.begin(), subt.end()), t += count(subt, standard);
        (ans[x] *= t) %= mod;
    }

    bool solve(int x, int fa){
        vector subt, bef;
        bool flag = true;
        for(int i = head[x]; ~i; i = nxt[i]){
            int y = to[i];
            if(c[y] == c[x]) continue;
            if(x == dc_u || x == dc_v){
                if(y == fa) continue;
                if(x+y == dc_u+dc_v){ flag = false; continue; }
            } else if(y == fa){ flag = false; continue;}
            if(flag) bef.push_back(y);
            else subt.push_back(y);
        }
        for(int y : bef) subt.push_back(y);
        bef.clear();
        siz[x] = ans[x] = 1;
        for(int y : subt) if(solve(y, x)){
            bef.push_back(y);
            (ans[x] *= ans[y]) %= mod;
            siz[x] += siz[y];
        }
        subt = bef;

        if(!square[x]){
            sort(subt.begin(), subt.end(), [&] (int a, int b){ return hash[a] < hash[b]; });
            hash[x] = 1;
        } else hash[x] = 2;
        for(int y : subt) hash[x] = (hash[x] * pow[siz[y]*2] + hash[y]) % mod;
        (hash[x] *= 3) %= mod;

        if(!square[x]){
            rep(i,0,(int)subt.size()-1){
                int cnt = 0;
                ll val = hash[subt[i]];
                while(i < subt.size() && hash[subt[i]] == val) i++, cnt++; i--;
                (ans[x] *= fact[cnt]) %= mod;
            }
        } else if(fa){
            vector rev = subt;
            reverse(rev.begin(), rev.end());
            bool flag = true;
            ll tmp = 2;
            rep(i,0,(int)rev.size()-1){
                flag &= (hash[rev[i]] == hash[subt[i]]);
                tmp = (tmp * pow[siz[rev[i]]*2] + hash[rev[i]]) % mod;
            }
            (tmp *= 3) %= mod, hash[x] = min(hash[x], tmp);
            if(flag) (ans[x] *= 2) %= mod;

        } else solve_for_square_root(x, subt);
        return true;
    }

    int trans(string s){
        int num = 0;
        for(char c : s) num = num*10 + c-'0';
        return num;
    }
    void get_edges(string s){
        string tmp;
        init();
        rep(i,0,(int)s.size()-1){
            tmp = "";
            while(i < s.size() && s[i] != ',') tmp += s[i++];
            int j = 0;
            while(tmp[j] != ' ') j++;
            int u = trans(tmp.substr(0, j)), v = trans(tmp.substr(j+1, tmp.size()-j-1));
            add_e(u, v, 1);
        }
    }

    int count(int n, vector edges){
        this->n = n;
        string s = "";
        for(string t : edges) s += t;
        get_edges(s);

        tarjan(1, 0);
        dfs(1, 0);
        int root;
        if(center.size() > 1){
            dc_u = center[0], dc_v = center[1];
            add_e(dc_u, ++this->n, 1), add_e(this->n, dc_v, 1);
            root = this->n;
            c[root] = -1;
        } else root = center[0];

        pow[0] = fact[0] = 1;
        rep(i,1,2*this->n) pow[i] = (pow[i-1] * 3) % mod, fact[i] = (fact[i-1] * i) % mod;

        solve(root, 0);
        return (int)ans[root];
    }
} solve;

int main(){
    int n, m;
    string s;
    vector edges;
    cin>>n>>m;
    getline(cin, s);
    rep(i,1,m) getline(cin, s), edges.push_back(s);
    cout<< solve.count(n, edges) <