ABC201E (位运算)


题意描述

有一棵\(N\)个节点的加权树,定义\(dist(x,y)\)为从\(x\)到\(y\)的路径上所有边权的异或值。对于所有满足\(1 \leq i < j \leq N\)的\((i,j)\),求出所有\(dist(i,j)\)之和,结果对\(10^9+7\)取模。

思路

将这棵树变为以任意一个点\(x\)为根的有根树。
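对于树上的任意两点\(i\)、\(j\)以及它们的最近公共祖先\(k\),都有\(dist(i,j)=dist(k,i) \oplus dist(k,j)\)。由于\(a \oplus a = 0\),在右侧再异或两次\(dist(x,k)\)不会改变结果;又因为\(k\)是\(i\)和\(j\)的祖先,根\(x\)到\(i\)(或\(j\))的路径必然经过\(k\),所以\(dist(x,k) \oplus dist(k,i) = dist(x,i)\)。据此对这个公式进行变形: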

\[\begin{aligned} dist(i,j) &= dist(k,i) \oplus dist(k,j) \oplus dist(x,k) \oplus dist(x,k) \\ &= (dist(x,k) \oplus dist(k,i)) \oplus (dist(x,k) \oplus dist(k,j)) \\ &= dist(x,i) \oplus dist(x,j) \end{aligned}\]

则问题变成了:对于所有的\((i,j)\),求出所有\(dist(x,i) \oplus dist(x,j)\)之和。
我们使用\(d[i]\)来表示从节点\(i\)到树根\(x\)路径上的异或值。根据异或的性质,对于某一个二进制位,只有当\(d[i]\)与\(d[j]\)在该位上不同(一个为\(0\)、另一个为\(1\))时,异或结果的该位才为\(1\);若两者相同(同为\(0\)或同为\(1\)),该位为\(0\)。因此可以按二进制位分别计算贡献。
可以从根出发使用 BFS 求出每个节点的\(d[i]\):若\(u\)是\(v\)的父节点,边权为\(w\),则\(d[v] = d[u] \oplus w\)。
对于第\(i\)个二进制位,设\(cnt_{0}\)为所有\(d[\cdot]\)中该位为\(0\)的个数,\(cnt_{1}\)为该位为\(1\)的个数。该位不同的点对恰好有\(cnt_{0} \times cnt_{1}\)个,每一对贡献\(2^i\),因此该位的总贡献为\(2^i \times cnt_{0} \times cnt_{1}\)。枚举每一位并将贡献累加(过程中对\(10^9+7\)取模)即为答案。

AC代码

#include <algorithm>
#include <cctype>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <iomanip>
#include <iostream>
#include <queue>
#include <utility>
#include <vector>


using namespace std;


#define fi first
#define se second
#define PB push_back
#define EB emplace_back
#define mst(x,a) memset(x,a,sizeof(x))
#define all(a) a.begin(),a.end()
// rep: half-open ascending loop, x = l, l+1, ..., u-1.
#define rep(x,l,u) for(int x=l;x<u;x++)
// per: inclusive descending loop, x = l, l-1, ..., u.
#define per(x,l,u) for(int x=l;x>=u;x--)
#define sz(x) (int)x.size()
#define IOS ios::sync_with_stdio(false);cin.tie(nullptr);
#define seteps(N) setprecision(N)
#define uni(x) sort(all(x)), x.erase(unique(all(x)), x.end())
#define lson (ind<<1)
#define rson (ind<<1|1)
#define endl '\n'
#define dbg(x) cerr << #x " = " << (x) << endl
#define mp make_pair
#define dbgfull(x) cerr << #x << " = " << x << " (line " << __LINE__ << ")" << endl

typedef long long ll;
typedef __int128 lll;  // GCC/Clang extension (128-bit signed integer)
// NOTE(review): the template arguments of the pair aliases below were lost
// when this code was extracted; they are reconstructed from the alias names
// (I = int, C = char, D = double, L = long long, B = bool) — confirm.
typedef pair<int, int> PII;
typedef pair<char, char> PCC;
typedef pair<double, double> PDD;
typedef pair<ll, ll> PLL;
typedef pair<int, PII> PIII;
typedef pair<ll, bool> PLB;


/// Minimal reader that pulls characters from stdin with getchar().
///
/// Flags (both start as true):
///   hasNext - cleared as soon as any read observes end of input.
///   hasRead - reset at the start of every call, then set when that call
///             consumed at least one character.
struct Scanner {

    bool hasNext = 1;
    bool hasRead = 1;

    /// Reads a decimal integer of type T. Leading non-digit characters are
    /// skipped, and every '-' among them flips the sign. The character that
    /// terminates the number is consumed. Returns 0 if no digits appear
    /// before end of input.
    template <class T>
    T nextIntegral() {
        hasRead = 0;
        T res = 0;
        T flag = 1;
        // Must be int, not char: getchar() reports EOF out of band, and a
        // char either confuses it with byte 0xFF (signed char) or never
        // matches it (unsigned char). isdigit() also requires a value that
        // is EOF or representable as unsigned char.
        int ch = getchar();

        while (ch != EOF && !isdigit(ch)) {
            hasRead = 1;
            if (ch == '-')
                flag = -flag;
            ch = getchar();
        }

        while (ch != EOF && isdigit(ch)) {
            hasRead = 1;
            res = res * 10 + (ch - '0');
            ch = getchar();
        }

        if (ch == EOF)
            hasNext = 0;

        return res * flag;
    }

    /// Reads an int (see nextIntegral).
    int nextInt() {
        return nextIntegral<int>();
    }

    /// Reads a long long (see nextIntegral).
    long long nextLL() {
        return nextIntegral<long long>();
    }

    /// Skips whitespace and returns the next character. At end of input the
    /// result is EOF converted to char.
    char nextChar() {
        hasRead = 0;
        int ch = getchar();

        while (ch != EOF && isspace(ch)) {
            hasRead = 1;
            ch = getchar();
        }

        if (ch == EOF)
            hasNext = 0;

        return static_cast<char>(ch);
    }

    /// Reads a whitespace-delimited token into str[1..len] (1-indexed) and
    /// writes a terminating 0 at str[len + 1]; returns len. The buffer must
    /// hold at least len + 2 chars. The whitespace ending the token is consumed.
    int nextString(char *str) {
        hasRead = 0;
        int len = 0;
        int ch = getchar();

        while (ch != EOF && isspace(ch)) {
            hasRead = 1;
            ch = getchar();
        }

        while (ch != EOF && !isspace(ch)) {
            hasRead = 1;
            str[++len] = static_cast<char>(ch);
            ch = getchar();
        }

        str[len + 1] = 0;

        if (ch == EOF)
            hasNext = 0;

        return len;
    }

} sc;

void rd(int &x) {
    x = sc.nextInt();
}

void rd(ll &x) {
    x = sc.nextLL();
}

void rd(char &x) {
    x = sc.nextChar();
}

void rd(char *x) {
    sc.nextString(x);
}

template
void rd(pair &x) {
    rd(x.first);
    rd(x.second);
}

template
void rd(T *x, int n) {
    for (int i = 1; i <= n; ++i)
        rd(x[i]);
}


/// Minimal writer that emits signed integers in decimal through putchar().
struct Printer {

    /// Writes x in decimal (with a leading '-' when negative), no separator.
    void printInt(int x) {
        // Widening to long long is exact, so INT_MIN goes through the
        // overflow-free path below instead of evaluating -INT_MIN (UB).
        printLL(x);
    }

    /// Writes x in decimal (with a leading '-' when negative), no separator.
    void printLL(long long x) {
        unsigned long long magnitude = static_cast<unsigned long long>(x);
        if (x < 0) {
            putchar('-');
            // Negate in unsigned arithmetic: well defined even for LLONG_MIN.
            magnitude = 0ULL - magnitude;
        }
        printDigits(magnitude);
    }

    /// Writes the decimal digits of u, most significant first.
    void printDigits(unsigned long long u) {
        if (u >= 10)
            printDigits(u / 10);
        putchar(static_cast<int>('0' + u % 10));
    }

} printer;

void pr(int x, char ch = '\n') {
    printer.printInt(x);
    putchar(ch);
}

void pr(ll x, char ch = '\n') {
    printer.printLL(x);
    putchar(ch);
}

template
void pr(pair x, char ch = '\n') {
#ifdef LOCAL
    putchar('<');
    pr(x.first, ' ');
    pr(x.second, '>');
    putchar(ch);
    return;
#endif // LOCAL
    pr(x.first, ' ');
    pr(x.second, ch);
}

template
void pr(T *x, int n) {
    for (int i = 1; i <= n; ++i)
        pr(x[i], " \n"[i == n]);
}

template
void pr(vector &x) {
    int n = x.size();

    for (int i = 1; i <= n; ++i)
        pr(x[i - 1], " \n"[i == n]);
}



// Global limits and numeric constants.
constexpr int N = 200005;        // bound for the per-vertex arrays g, d and st
constexpr int M = 20020;         // 5005 * 4; not referenced in the visible code
constexpr int INF = 0x3f3f3f3f;  // large sentinel; not referenced in the visible code
constexpr int mod = 1000000007;  // modulus applied to the final answer
const lll oone = 1;              // 128-bit one; not referenced in the visible code
constexpr double eps = 0.0001;   // floating-point tolerance; not referenced in the visible code
const double pi = acos(-1.0);

/// Adjacency-list entry: the vertex at the other end of an edge and the
/// edge's weight.
struct NODE
{
    int v;  // neighbouring vertex id
    ll w;   // weight of the edge leading to v
    NODE(int v, ll w) : v{v}, w{w} {}
};

// Adjacency lists of the tree, indexed by vertex id (vertices are 1..n).
vector<NODE> g[N];
// Number of vertices.
int n;
// d[i]: XOR of edge weights on the path from the BFS root (vertex 1) to i;
// -1 marks "not yet visited" while the BFS runs.
ll d[N];
// NOTE(review): st is not referenced in the visible code.
bool st[N];

struct Solver {

    void InitOnce() {

    }

    void Read() {
        rd(n);
        rep(i,0,n-1){
            int u,v;
            ll w;
            rd(u);rd(v);rd(w);
            g[u].EB(v,w);
            g[v].EB(u,w);
        }
    }


    void Solve() {
        queue q;
        q.push(1);
        mst(d,-1);
        d[1] = 0;
        while(sz(q)){
            int t = q.front();
            q.pop();
            for(auto e : g[t]){
                int v = e.v;
                ll w = e.w;
                if(d[v] == -1){
                    d[v] = d[t] ^ w;
                    q.push(v);
                }
            }
        }



        ll ans = 0;
        rep(i,0,61){
            vector cnt(2,0);
            rep(j,1,n+1) cnt[d[j]>>i&1]++;
            ans = (ans + (1ll << i) % mod * cnt[0] % mod * cnt[1] % mod) % mod;
        }
        pr(ans % mod);
    }


} solver;
int main() {
#ifdef LOCAL
    freopen("data.in", "r", stdin);
    freopen("data.out", "w", stdout);
#endif
    solver.InitOnce();
    int t = 1;
    //rd(t);
    rep(i, 1, t + 1) {
        solver.Read();

        if (!sc.hasRead)
            break;

        solver.Solve();

        if (!sc.hasNext)
            break;
    }
    return 0;
}