ABC201E (位运算)
题意描述
有一棵\(N\)个节点的边权树,定义\(dist(x,y)\)为从\(x\)到\(y\)的路径上所有边权的异或。求所有满足\(1 \leq i < j \leq N\)的点对\((i,j)\)的\(dist(i,j)\)之和,结果对\(10^9+7\)取模。
思路
将这棵树变为以任意一个点\(x\)为根的有根树。
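对于树上的任意两点\(i\)和\(j\)以及它们的最近公共祖先\(k\),都有\(dist(i,j)=dist(k,i) \oplus dist(k,j)\)。由于\(k\)位于根\(x\)到\(i\)(以及到\(j\))的路径上,有\(dist(x,i)=dist(x,k) \oplus dist(k,i)\),再利用\(a \oplus a = 0\),对这个公式进行变形:
\[dist(i,j)=dist(k,i)\oplus dist(k,j)=\big(dist(x,i)\oplus dist(x,k)\big)\oplus\big(dist(x,j)\oplus dist(x,k)\big)=dist(x,i)\oplus dist(x,j)\]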
则问题变成了:对于所有的\((i,j)\),求出\(dist(x,i) \oplus dist(x,j)\)之和。该式与\(k\)无关,因此无需求最近公共祖先。
我们使用\(d[i]\)来表示从节点\(i\)到树根\(x\)路径上的异或值,于是\(dist(i,j)=d[i] \oplus d[j]\)。根据异或的性质,某一位上两数相同(同为\(0\)或同为\(1\))时,异或结果为\(0\);不同时为\(1\)。
可以使用\(bfs\)求出每个节点到根节点的路径值。
对每一个二进制位\(b\)分别统计:令\(cnt_{0}\)表示所有\(d[\cdot]\)中该位为\(0\)的个数,\(cnt_{1}\)表示该位为\(1\)的个数。该位上取值不同的点对恰有\(cnt_{0} \cdot cnt_{1}\)个,每对贡献\(2^b\),所以该位的总贡献为\(2^b \cdot cnt_{0} \cdot cnt_{1}\)。将所有位的贡献累加(过程中取模)就是答案。
AC代码
#include <bits/stdc++.h>
using namespace std;
// Competitive-programming shorthand macros and type aliases.
// Several lines here were reconstructed: the extracted source had every
// "<...>" span stripped, which fused `rep` with a reverse-loop macro and
// swallowed the integer typedefs and the pair template arguments.
#define fi first
#define se second
#define PB push_back
#define EB emplace_back
#define mst(x,a) memset(x,a,sizeof(x))
#define all(a) a.begin(),a.end()
// Half-open ascending loop: x = l, l+1, ..., u-1.
#define rep(x,l,u) for(int x=l;x<u;x++)
// Closed descending loop: x = l, l-1, ..., u.
// NOTE(review): the original name of this macro was lost in extraction.
#define per(x,l,u) for(int x=l;x>=u;x--)
#define sz(x) (int)x.size()
#define IOS ios::sync_with_stdio(false);cin.tie(nullptr);
#define seteps(N) setprecision(N)
#define uni(x) sort(all(x)), x.erase(unique(all(x)), x.end())
#define lson (ind<<1)
#define rson (ind<<1|1)
#define endl '\n'
#define dbg(x) cerr << #x " = " << (x) << endl
#define mp make_pair
#define dbgfull(x) cerr << #x << " = " << x << " (line " << __LINE__ << ")" << endl
typedef long long ll;
typedef __int128 lll;  // GCC/Clang extension; used for `oone`
typedef std::pair<int, int> PII;
// NOTE(review): the element types of the aliases below were stripped during
// extraction and are reconstructed from their names; none is used in this file.
typedef std::pair<char, char> PCC;
typedef std::pair<double, double> PDD;
typedef std::pair<ll, ll> PLL;
typedef std::pair<int, PII> PIII;
typedef std::pair<ll, bool> PLB;
/// Fast getchar()-based token reader over stdin.
///
/// hasNext: cleared once any read reaches EOF.
/// hasRead: after each call, true iff that call consumed at least one
///          character before hitting EOF (main() uses it to stop early).
///
/// Characters are held in an `int`, so they can be compared with EOF safely
/// (a plain `char` never equals EOF where char is unsigned). This also keeps
/// isdigit/isspace arguments in the range those functions accept.
struct Scanner {
    bool hasNext = 1;
    bool hasRead = 1;

    /// Reads the next integer. Every '-' seen before the digits flips the sign.
    /// The single character following the number is consumed.
    int nextInt() {
        return readSigned<int>();
    }

    /// 64-bit counterpart of nextInt().
    long long nextLL() {
        return readSigned<long long>();
    }

    /// Skips whitespace and returns the next character ((char)EOF at end of input).
    char nextChar() {
        hasRead = 0;
        int ch = getchar();
        while (ch != EOF && isspace(ch)) {
            hasRead = 1;
            ch = getchar();
        }
        if (ch == EOF)
            hasNext = 0;
        return static_cast<char>(ch);
    }

    /// Reads a whitespace-delimited token into str[1..len] (1-indexed) and
    /// writes str[len + 1] = 0. Returns len.
    /// No bounds checking: the caller must provide enough room.
    int nextString(char *str) {
        hasRead = 0;
        int len = 0;
        int ch = getchar();
        while (ch != EOF && isspace(ch)) {
            hasRead = 1;
            ch = getchar();
        }
        while (ch != EOF && !isspace(ch)) {
            hasRead = 1;
            str[++len] = static_cast<char>(ch);
            ch = getchar();
        }
        str[len + 1] = 0;
        if (ch == EOF)
            hasNext = 0;
        return len;
    }

private:
    /// Shared implementation of nextInt/nextLL: skip non-digits (tracking '-'),
    /// then accumulate decimal digits into a T.
    template <class T>
    T readSigned() {
        hasRead = 0;
        T res = 0;
        T flag = 1;
        int ch = getchar();
        while (ch != EOF && !isdigit(ch)) {
            hasRead = 1;
            if (ch == '-')
                flag = -flag;
            ch = getchar();
        }
        while (ch != EOF && isdigit(ch)) {
            hasRead = 1;
            res = res * 10 + (ch - '0');
            ch = getchar();
        }
        if (ch == EOF)
            hasNext = 0;
        return res * flag;
    }
} sc;
/// rd(...): read one value from stdin through the global Scanner `sc`.
void rd(int &x) {
    x = sc.nextInt();
}
void rd(long long &x) {
    x = sc.nextLL();
}
void rd(char &x) {
    x = sc.nextChar();
}
/// Reads one token into x[1..len] (1-indexed, NUL-terminated; see Scanner::nextString).
void rd(char *x) {
    sc.nextString(x);
}
/// Reads both members of a pair, first then second.
template <class A, class B>
void rd(std::pair<A, B> &x) {
    rd(x.first);
    rd(x.second);
}
/// Reads n values into x[1..n] (1-indexed); x must have room for index n.
template <class T>
void rd(T *x, int n) {
    for (int i = 1; i <= n; ++i)
        rd(x[i]);
}
/// Minimal putchar()-based integer writer.
/// The magnitude is computed in the matching unsigned type, so INT_MIN and
/// LLONG_MIN print correctly instead of overflowing on `x = -x`.
struct Printer {
    /// Writes x in decimal (with a leading '-' if negative); no separator.
    void printInt(int x) {
        printSigned(x);
    }
    /// 64-bit counterpart of printInt().
    void printLL(long long x) {
        printSigned(x);
    }

private:
    template <class T>
    void printSigned(T x) {
        using U = typename std::make_unsigned<T>::type;
        U mag = static_cast<U>(x);
        if (x < 0) {
            putchar('-');
            mag = U(0) - mag;  // modular negation: well-defined for every value
        }
        printUnsigned(mag);
    }

    /// Emits the decimal digits of v (at least one digit, so 0 prints "0").
    template <class U>
    void printUnsigned(U v) {
        char buf[24];
        int len = 0;
        do {
            buf[len++] = static_cast<char>('0' + v % 10);
            v /= 10;
        } while (v != 0);
        while (len > 0)
            putchar(buf[--len]);
    }
} printer;
/// pr(...): write a value through the global Printer `printer`, then `ch`.
void pr(int x, char ch = '\n') {
    printer.printInt(x);
    putchar(ch);
}
void pr(long long x, char ch = '\n') {
    printer.printLL(x);
    putchar(ch);
}
/// Prints "first second" then ch. In LOCAL builds it prints "<first second>" then ch.
template <class A, class B>
void pr(const std::pair<A, B> &x, char ch = '\n') {
#ifdef LOCAL
    putchar('<');
    pr(x.first, ' ');
    pr(x.second, '>');
    putchar(ch);
    return;
#endif // LOCAL
    pr(x.first, ' ');
    pr(x.second, ch);
}
/// Prints x[1..n] (1-indexed) space-separated, newline after the last one.
template <class T>
void pr(T *x, int n) {
    for (int i = 1; i <= n; ++i)
        pr(x[i], " \n"[i == n]);
}
/// Prints all elements of x space-separated, newline after the last one.
/// An empty vector prints nothing.
template <class T>
void pr(const std::vector<T> &x) {
    int n = static_cast<int>(x.size());
    for (int i = 1; i <= n; ++i)
        pr(x[i - 1], " \n"[i == n]);
}
// Global limits and numeric constants.
constexpr int N = 200005;          // array bound: 2e5 + 5
constexpr int M = 20020;           // 5005 * 4; not referenced below
constexpr int INF = 0x3f3f3f3f;    // "infinity" sentinel; not referenced below
constexpr int mod = 1000000007;    // answer modulus, 1e9 + 7
const lll oone = 1;                // 128-bit literal one
constexpr double eps = 1e-4;
const double pi = acos(-1);
/// One adjacency-list entry: the neighbouring vertex `v` and the weight `w`
/// of the edge leading to it.
struct NODE {
    int v;
    ll w;
    NODE(int to, ll weight) : v{to}, w{weight} {}
};
// Adjacency lists of the tree, indexed by 1-based vertex id.
vector<NODE> g[N];
// Number of vertices.
int n;
// d[v]: XOR of edge weights on the path from vertex 1 to v; the solver
// memsets it to -1 to mark vertices not yet reached by the BFS.
ll d[N];
// Not referenced in this file.
bool st[N];
struct Solver {
void InitOnce() {
}
void Read() {
rd(n);
rep(i,0,n-1){
int u,v;
ll w;
rd(u);rd(v);rd(w);
g[u].EB(v,w);
g[v].EB(u,w);
}
}
void Solve() {
queue q;
q.push(1);
mst(d,-1);
d[1] = 0;
while(sz(q)){
int t = q.front();
q.pop();
for(auto e : g[t]){
int v = e.v;
ll w = e.w;
if(d[v] == -1){
d[v] = d[t] ^ w;
q.push(v);
}
}
}
ll ans = 0;
rep(i,0,61){
vector cnt(2,0);
rep(j,1,n+1) cnt[d[j]>>i&1]++;
ans = (ans + (1ll << i) % mod * cnt[0] % mod * cnt[1] % mod) % mod;
}
pr(ans % mod);
}
} solver;
int main() {
#ifdef LOCAL
freopen("data.in", "r", stdin);
freopen("data.out", "w", stdout);
#endif
solver.InitOnce();
int t = 1;
//rd(t);
rep(i, 1, t + 1) {
solver.Read();
if (!sc.hasRead)
break;
solver.Solve();
if (!sc.hasNext)
break;
}
return 0;
}