禁止套娃: 线段树套宗法树


树套树

这是一种思想, 不是什么特定的数据结构, 不过实现起来一般都是从外层的树形数据结构的每个节点上, 挂一个内层树形数据结构的根的指针, 这样, 本来是要查询或修改外层节点的信息的行为, 变成了在外层某节点对应的内层数据结构上查询或修改某种信息.

容易发现, 这种思想非常占空间, 所以一般这种数据结构尤其是内层, 必须动态开点或直接实现可持久化.

模板

直接来看一道题来感受这种思想:

维护一个序列, 支持:

  • 区间查询第 \(k\) 小的数

  • 区间查排名

  • 单点修改

  • 区间查前驱

  • 区间查后继

官方做法是用线段树做外层数据结构, 每个节点挂一棵平衡树, 存储这个线段树节点表示的区间的数, 因为平衡树本来就是动态开点的, 所以空间是可以接受的, 加起来是 \(O(nlogn)\).

区间查排名, 只需要 \(O(logn)\) 找出这个区间对应的 \(O(logn)\) 棵平衡树, 然后对每个平衡树 \(O(logn)\) 查询排名, 总共 \(O(log^2n)\).

单点修改也很简单, \(O(logn)\) 找出包含这个单点的 \(O(logn)\) 棵平衡树, 每棵平衡树 \(O(logn)\) 删除, \(O(logn)\) 插入即可.

区间查前驱后继也很简单, 对 \(O(logn)\) 棵平衡树分别进行 \(1\) 次复杂度为 \(O(logn)\) 次查询前驱/后继, 取最大/最小的那个, 总时间 \(O(log^2n)\).

对于区间查询第 \(k\) 小的数, 因为是在 \(logn\) 棵树上查询, 所以不能像普通平衡树一样, 二分查找这个位置, \(O(logn)\) 查询. 只能二分答案, \(O(logn)\) 的二分答案, \(O(logn)\) 的线段树上查询, \(O(logn)\) 的平衡树上查询, 总复杂度 \(O(log^3n)\).

所以对于 \(m\) 次操作的总复杂度应该是 \(O(mlog^3n)\).

代码

首先是平衡树部分, 这里采用了宗法树.

struct SubNode {
  SubNode *LS, *RS;
  int Val, Ival;
  unsigned Size;
} SN[2000005], *CntSN(SN);

旋转

void Rotate(SubNode *x) {
  x->Size = x->LS->Size + x->RS->Size;
  x->Ival = x->LS->Ival, x->Val = x->RS->Val;
  if(x->LS->Size * 3 < x->RS->Size) {
    register SubNode *Move(x->RS);
    x->RS = Move->RS;
    Move->RS = Move->LS;
    Move->LS = x->LS;
    x->LS = Move;
    Move->Ival = Move->LS->Ival, Move->Val = Move->RS->Val, Move->Size = Move->LS->Size + Move->RS->Size;
    return;
  }
  if(x->RS->Size * 3 < x->LS->Size) {
    register SubNode *Move(x->LS);
    x->LS = Move->LS;
    Move->LS = Move->RS;
    Move->RS = x->RS;
    x->RS = Move;
    Move->Ival = Move->LS->Ival, Move->Val = Move->RS->Val, Move->Size = Move->LS->Size + Move->RS->Size; 
  }
}

插入

SubNode *Insert(SubNode *x) {
  if(x->Size == 1) {
    SubNode *Fa(++CntSN);
    if(x->Val < OpVal) {
      Fa->LS = x;
      Fa->RS = ++CntSN;
    } else {
      Fa->RS = x;
      Fa->LS = ++CntSN;
    }
    CntSN->Val = CntSN->Ival = OpVal, CntSN->Size = 1;
    Fa->Size = 2;
    Fa->Ival = Fa->LS->Ival, Fa->Val = Fa->RS->Val;
    return Fa;
  }
  if(x->LS->Val < OpVal) x->RS = Insert(x->RS);
  else x->LS = Insert(x->LS);
  Rotate(x);
  return x;
}

删除

SubNode *Delete(SubNode *x) {
  if(x->LS->Val < OpTmp) {
    if(x->RS->Size == 1) {
      if(x->RS->Val == OpTmp) {
        return x->LS;
      } else {
        return x;
      }
    }
    x->RS = Delete(x->RS);
  } else {
    if(x->LS->Size == 1) {
      if(x->LS->Val == OpTmp) {
        return x->RS;
      } else {
        return x;
      }
    }
    x->LS = Delete(x->LS);
  }
  Rotate(x);
  return x; 
}

查排名

void SubRank(SubNode *x) {
  if (x->Size == 1) {
    if (x->Val < OpVal) ++Ans;
    return;
  }
  if (x->LS->Val < OpVal)
    Ans += x->LS->Size, SubRank(x->RS);
  else
    SubRank(x->LS);
}

查前驱

void SubPre(SubNode *x) {
  if (x->Size == 1) {
    if (x->Val < OpVal) Ans = max(Ans, x->Val);
    return;
  }
  if (x->RS->Ival >= OpVal) {
    SubPre(x->LS);
  } else {
    SubPre(x->RS);
  }
}

查后继

void SubSuc(SubNode *x) {
  if (x->Size == 1) {
    if (x->Val > OpVal) Ans = min(Ans, x->Val);
    return;
  }
  if (x->LS->Val < OpVal) {
    SubSuc(x->RS);
  } else {
    SubSuc(x->LS);
  }
}

然后是线段树部分, 最普通的线段树即可.

struct Node {
  Node *LS, *RS;
  SubNode *Root;
} N[100005], *CntN(N);

建树

void Build(Node *x, unsigned L, unsigned R) {
  x->Root = ++CntSN;
  x->Root->Val = x->Root->Ival = a[L];
  x->Root->Size = 1;
  if (L == R) return;
  for (register unsigned i(L + 1); i <= R; ++i) {
    OpVal = a[i], x->Root = Insert(x->Root);
  }
  register unsigned Mid((L + R) >> 1);
  Build(x->LS = ++CntN, L, Mid);
  Build(x->RS = ++CntN, Mid + 1, R);
}

单点修改

就是先插入一个新值, 然后将原值删除.

void Change(Node *x, unsigned L, unsigned R) {
  x->Root = Insert(x->Root);
  x->Root = Delete(x->Root);
  if (L == R) {
    return;
  }
  register unsigned Mid((L + R) >> 1);
  if (Mid < OpL) {
    Change(x->RS, Mid + 1, R);
  } else {
    Change(x->LS, L, Mid);
  }
  return;
}

查排名

void Rank(Node *x, unsigned L, unsigned R) {
  if (L >= OpL && R <= OpR) {
    SubRank(x->Root);
    return;
  }
  register unsigned Mid((L + R) >> 1);
  if (OpL <= Mid) {
    Rank(x->LS, L, Mid);
  }
  if (Mid < OpR) {
    Rank(x->RS, Mid + 1, R);
  }
  return;
}

查前驱

void Pre(Node *x, unsigned L, unsigned R) {
  if (L >= OpL && R <= OpR) {
    return SubPre(x->Root);
  }
  register unsigned Mid((L + R) >> 1);
  if (OpL <= Mid) {
    Pre(x->LS, L, Mid);
  }
  if (Mid < OpR) {
    Pre(x->RS, Mid + 1, R);
  }
}

查后继

void Suc(Node *x, unsigned L, unsigned R) {
  if (L >= OpL && R <= OpR) {
    return SubSuc(x->Root);
  }
  register unsigned Mid((L + R) >> 1);
  if (OpL <= Mid) {
    Suc(x->LS, L, Mid);
  }
  if (Mid < OpR) {
    Suc(x->RS, Mid + 1, R);
  }
}

二分答案

void Find() {
  register int L(0), R(100000000), Mid;
  while (L < R) {
    Mid = ((L + R + 1) >> 1);
    Ans = 1, OpVal = Mid, Rank(N, 1, n);
    if (Ans > OpTmp)
      R = Mid - 1;
    else
      L = Mid;
  }
  Ans = L;
}

主函数

unsigned a[50005], m, n, Cnt(0), OpL, OpR, A, B, C, D, t, Tmp(0);
int Ans, OpVal, OpTmp;
int main() {
  n = RD(), m = RD();
  for (register unsigned i(1); i <= n; ++i) a[i] = RD();
  Build(N, 1, n);
  for (register unsigned i(1); i <= m; ++i) {
    A = RD();
    switch (A) {
      case 1: {
        OpL = RD(), OpR = RD(), OpVal = RD();
        Ans = 1, Rank(N, 1, n);
        break;
      }
      case 2: {
        OpL = RD(), OpR = RD(), OpTmp = RD();
        Find();
        break;
      }
      case 3: {
        OpL = RD();
        OpTmp = a[OpL];
        a[OpL] = OpVal = RD();
        Change(N, 1, n);
        break;
      }
      case 4: {
        OpL = RD(), OpR = RD(), OpVal = RD();
        Ans = -2147483647, Pre(N, 1, n);
        break;
      }
      case 5: {
        OpL = RD(), OpR = RD(), OpVal = RD();
        Ans = 2147483647, Suc(N, 1, n);
        break;
      }
    }
    if (A != 3) {
      printf("%d\n", Ans);
    }
  }
  return Wild_Donkey;
}