AVL树和红黑树01:平衡二叉树和AVL树


二分搜索树最大的问题是:如果按顺序插入有序数据,它会退化成链表,查找、插入、删除的时间复杂度从O(log n)退化为O(n)。因此需要在二分搜索树的基础上添加一些机制,使其始终保持平衡

而AVL树是最早被提出的自平衡二分搜索树,它能在插入和删除时自动调整节点分布,避免退化成链表

平衡二叉树

完全二叉树、满二叉树都是平衡二叉树,其所有叶子节点的高度相差不超过1,平衡度很高

AVL树

而AVL树的要求相对较低,对于任意一个节点,只要左子树和右子树的高度差不超过1即可

在二分搜索树的基础上,为每个节点记录高度:空节点高度为0,非空节点的高度为左右子树高度的最大值加1。每个节点的左子树高度减去右子树高度,称为该节点的平衡因子;AVL树要求所有节点平衡因子的绝对值不超过1

实现AVL树的前置工作

新增getHeight()、getBalanceFactor()、isBST()和isBalanced()方法

添加元素后,在记录子树高度和平衡因子的同时,也要检查AVL树是否还维持着二分搜索树和平衡二叉树的性质

import java.util.ArrayList;

/**
 * Minimal key-value map abstraction implemented by the tree classes in this tutorial.
 *
 * @param <K> key type
 * @param <V> value type
 */
interface Map<K, V>{

    /** Adds a mapping from {@code key} to {@code value}. */
    void add(K key, V value);
    /** Removes the mapping for {@code key} and returns its value. */
    V remove(K key);
    /** Returns whether a mapping for {@code key} exists. */
    boolean contains(K key);
    /** Returns the value mapped to {@code key}. */
    V get(K key);
    /** Replaces the value mapped to {@code key}. */
    void set(K key, V value);
    /** Returns the number of mappings. */
    int getSize();
    /** Returns whether there are no mappings. */
    boolean isEmpty();
}

/**
 * 在二分搜索树的基础上添加记录子树的高度
 */
class AVLTree, V> implements Map{

    class Node{

        public K key;
        public V value;
        public Node leftNext;
        public Node rightNext;
        public int height;

        /**
         * 初始高度均为1,即自身;空节点高度则为0
         */
        public Node(K key, V value){

            this.key = key;
            this.value = value;
            leftNext = null;
            rightNext = null;
            height = 1;
        }
    }

    private Node root;
    private int size;

    public AVLTree(){

        root = null;
        size = 0;
    }

    @Override
    public int getSize(){

        return size;
    }

    @Override
    public boolean isEmpty(){

        return size == 0;
    }

    /**
     * 获取某节点子树的高度
     */
    private int getHeight(Node node){

        if (node == null){
            return 0;
        }
        else {
            return node.height;
        }
    }

    /**
     * 计算某节点的平衡因子
     */
    private int getBalanceFactor(Node node){

        if (node == null){
            return 0;
        }
        else {
            return getHeight(node.leftNext) - getHeight(node.rightNext);
        }
    }

    /**
     * 检查是否还是二分搜索树
     * 利用二分搜索树的中序遍历是有序的这一性质
     */
    private boolean isBST(Node root){

        ArrayList ls = new ArrayList<>();
        inOrder(root, ls);

        for (int i = 0; i < ls.size() - 1; i++) {

            if (ls.get(i).compareTo(ls.get(i + 1)) > 0){
                return false;
            }
        }

        return true;
    }

    private void inOrder(Node root, ArrayList ls){

        if (root == null){
            return;
        }

        inOrder(root.leftNext, ls);
        ls.add(root.key);
        inOrder(root.rightNext, ls);
    }

    /**
     * 检查是否还是平衡二叉树
     * 判断节点的平衡因子是否大于1
     */
    public boolean isBalanced(){

        return isBalanced(root);
    }

    private boolean isBalanced(Node root){

        if (root == null){
            return true;
        }

        int balanceFactor = getBalanceFactor(root);

        if (Math.abs(balanceFactor) > 1){
            return false;
        }
        else {
            return isBalanced(root.leftNext) && isBalanced(root.rightNext);
        }
    }

    private Node getNode(Node root, K key){

        if (root == null){
            return null;
        }

        if (root.key.compareTo(key) == 0){
            return root;
        }
        else if (root.key.compareTo(key) > 0){
            return getNode(root.leftNext, key);
        }
        else {
            return getNode(root.rightNext, key);
        }
    }

    @Override
    public boolean contains(K key){

        return getNode(root, key) != null;
    }

    @Override
    public V get(K key){

        Node res = getNode(root, key);

        return res == null ? null : res.value;
    }

    @Override
    public void add(K key, V value){

        root = add(root, key, value);
    }

    private Node add(Node root, K key, V value){

        if (root == null){

            size++;
            return new Node(key, value);
        }

        if (root.key.compareTo(key) > 0){
            root.leftNext = add(root.leftNext, key, value);
        }
        else if (root.key.compareTo(key) < 0){
            root.rightNext = add(root.rightNext, key, value);
        }
        else {
            root.value = value;
        }

        /**
         * 添加元素后计算高度值,值为子树的最大高度加1
         */
        root.height = Math.max(getHeight(root.leftNext), getHeight(root.rightNext)) + 1;

        /**
         * 计算平衡因子,如果大于1说明失衡了
         */
        int balanceFactor = getBalanceFactor(root);

        if (Math.abs(balanceFactor) > 1){
            System.out.println("平衡因子为:" + balanceFactor);
        }

        return root;
    }

    @Override
    public void set(K key, V value) {

        Node res = getNode(root, key);

        if (res == null){
            throw new IllegalArgumentException("键值不存在");
        }
        else {
            res.value = value;
        }
    }

    private Node Min(Node root){

        if (root.leftNext == null){
            return root;
        }

        return Min(root.leftNext);
    }

    private Node removeMin(Node root){

        if (root.leftNext == null){

            Node newRoot = root.rightNext;
            root.rightNext = null;
            size--;

            return newRoot;
        }

        root.leftNext = removeMin(root.leftNext);

        return root;
    }

    @Override
    public V remove(K key){

        Node res = getNode(root, key);

        if (res == null){
            return null;
        }

        root = remove(root, key);

        return res.value;
    }

    private Node remove(Node root, K key){

        if (root == null){
            return null;
        }

        if (root.key.compareTo(key) > 0){

            root.leftNext = remove(root.leftNext, key);
            return root;
        }
        else if (root.key.compareTo(key) < 0){

            root.rightNext = remove(root.rightNext, key);
            return root;
        }
        else {

            if (root.leftNext == null) {

                Node newRoot = root.rightNext;
                root.rightNext = null;
                size--;

                return newRoot;
            }
            else if (root.rightNext == null) {

                Node newRoot = root.leftNext;
                root.leftNext = null;
                size--;

                return newRoot;
            }
            else {

                Node min = Min(root.rightNext);
                min.rightNext = removeMin(root.rightNext);
                min.leftNext = root.leftNext;
                root.leftNext = null;
                root.rightNext = null;

                return min;
            }
        }
    }
}

AVL树如何实现自平衡

插入节点后如果破坏了平衡,失衡的只可能是新节点的祖先节点,因为只有它们的子树高度会发生变化。因此可以在递归返回时,沿插入路径自底向上检查并维护平衡性

AVL树的左旋转和右旋转

LL右旋转

插入的元素位于不平衡节点的左孩子的左子树中(LL情况),此时对不平衡节点进行一次右旋转即可恢复平衡

  • 将y连接在x节点的右孩子上
  • 将T3连接在y的左孩子上
  • 更新y和x的高度值

/**
 * Recursively inserts {@code key}/{@code value} into the subtree rooted at {@code root} and
 * returns the subtree's new root. In this step only the LL case is repaired, using a single
 * right rotation.
 */
private Node add(Node root, K key, V value){

    // Reached an empty slot: attach a new leaf here.
    if (root == null){
        size++;
        return new Node(key, value);
    }

    int cmp = root.key.compareTo(key);
    if (cmp > 0){
        root.leftNext = add(root.leftNext, key, value);
    }
    else if (cmp < 0){
        root.rightNext = add(root.rightNext, key, value);
    }
    else {
        // Duplicate key: overwrite the value; size is unchanged.
        root.value = value;
    }

    int leftHeight = getHeight(root.leftNext);
    int rightHeight = getHeight(root.rightNext);
    root.height = 1 + Math.max(leftHeight, rightHeight);

    // LL: left-heavy by more than 1 and the left child is not right-heavy.
    boolean leftLeft = getBalanceFactor(root) > 1 && getBalanceFactor(root.leftNext) >= 0;
    return leftLeft ? rightRotate(root) : root;
}

/**
 * LL case: right rotation around y.
 *
 *         y                                x
 *        / \                             /   \
 *       x   T4     rightRotate(y)       z     y
 *      / \       ---------------->     / \   / \
 *     z   T3                         T1  T2 T3  T4
 *    / \
 *  T1   T2
 *
 * y's left child x becomes the new subtree root, y becomes x's right child, and x's former
 * right subtree T3 becomes y's left subtree. Returns x.
 */
private Node rightRotate(Node y){

    Node x = y.leftNext;
    // T3 (x's right subtree) moves across to become y's left subtree ...
    y.leftNext = x.rightNext;
    // ... and y drops down to become x's right child.
    x.rightNext = y;

    // y now sits below x, so its height must be refreshed first.
    y.height = 1 + Math.max(getHeight(y.leftNext), getHeight(y.rightNext));
    x.height = 1 + Math.max(getHeight(x.leftNext), getHeight(x.rightNext));

    return x;
}

RR左旋转

插入的元素位于不平衡节点的右孩子的右子树中(RR情况),此时对不平衡节点进行一次左旋转即可恢复平衡

  • 将y连接在x节点的左孩子上
  • 将T3连接在y的右孩子上
  • 更新y和x的高度值

/**
 * Recursively inserts {@code key}/{@code value} into the subtree rooted at {@code root} and
 * returns the subtree's new root. In this step only the RR case is repaired, using a single
 * left rotation.
 */
private Node add(Node root, K key, V value){

    // Reached an empty slot: attach a new leaf here.
    if (root == null){
        size++;
        return new Node(key, value);
    }

    int cmp = root.key.compareTo(key);
    if (cmp > 0){
        root.leftNext = add(root.leftNext, key, value);
    }
    else if (cmp < 0){
        root.rightNext = add(root.rightNext, key, value);
    }
    else {
        // Duplicate key: overwrite the value; size is unchanged.
        root.value = value;
    }

    int leftHeight = getHeight(root.leftNext);
    int rightHeight = getHeight(root.rightNext);
    root.height = 1 + Math.max(leftHeight, rightHeight);

    // RR: right-heavy by more than 1 and the right child is not left-heavy.
    boolean rightRight = getBalanceFactor(root) < -1 && getBalanceFactor(root.rightNext) <= 0;
    return rightRight ? leftRotate(root) : root;
}

/**
 * RR case: left rotation around y.
 *
 *     y                                  x
 *    / \                               /   \
 *   T1  x          leftRotate(y)      y     z
 *      / \       ---------------->   / \   / \
 *     T2  z                        T1  T2 T3  T4
 *        / \
 *      T3   T4
 *
 * y's right child x becomes the new subtree root, y becomes x's left child, and x's former
 * left subtree T2 becomes y's right subtree. Returns x.
 */
private Node leftRotate(Node y){

    Node x = y.rightNext;
    // x's left subtree moves across to become y's right subtree ...
    y.rightNext = x.leftNext;
    // ... and y drops down to become x's left child.
    x.leftNext = y;

    // y now sits below x, so its height must be refreshed first.
    y.height = 1 + Math.max(getHeight(y.leftNext), getHeight(y.rightNext));
    x.height = 1 + Math.max(getHeight(x.leftNext), getHeight(x.rightNext));

    return x;
}

LR(先左后右)

插入的元素位于不平衡节点的左孩子的右子树中(LR情况),需要先左旋转、再右旋转来恢复平衡

  • 先对y的左孩子x进行左旋转

  • 再对y进行右旋转

/**
 * Recursively inserts {@code key}/{@code value} into the subtree rooted at {@code root} and
 * returns the subtree's new root. In this step only the LR case is repaired: a left rotation
 * of the left child followed by a right rotation of this node.
 */
private Node add(Node root, K key, V value){

    // Reached an empty slot: attach a new leaf here.
    if (root == null){
        size++;
        return new Node(key, value);
    }

    int cmp = root.key.compareTo(key);
    if (cmp > 0){
        root.leftNext = add(root.leftNext, key, value);
    }
    else if (cmp < 0){
        root.rightNext = add(root.rightNext, key, value);
    }
    else {
        // Duplicate key: overwrite the value; size is unchanged.
        root.value = value;
    }

    int leftHeight = getHeight(root.leftNext);
    int rightHeight = getHeight(root.rightNext);
    root.height = 1 + Math.max(leftHeight, rightHeight);

    // LR: left-heavy by more than 1 while the left child leans right.
    boolean leftRight = getBalanceFactor(root) > 1 && getBalanceFactor(root.leftNext) < 0;
    if (!leftRight){
        return root;
    }

    // Turn the LR shape into an LL shape, then fix it with a right rotation.
    root.leftNext = leftRotate(root.leftNext);
    return rightRotate(root);
}

/**
 * LR case (left-right): first rotate y's left child x to the left, then rotate y itself to
 * the right.
 *
 *          y                              y                              z
 *        /  \                            / \                           /   \
 *       x    T4   leftRotate(x)         z  T4   rightRotate(y)        x     y
 *     /  \        ------------->       / \      -------------->      / \   / \
 *   T1    z                           x  T3                        T1  T2 T3  T4
 *       /  \                         / \
 *     T2   T3                      T1   T2
 *
 * This method performs a left rotation around its argument (x in the first step above): the
 * argument's right child becomes the new subtree root and the argument becomes its left
 * child. Returns the new subtree root.
 */
private Node leftRotate(Node y){

    Node x = y.rightNext;
    // x's left subtree moves across to become y's right subtree ...
    y.rightNext = x.leftNext;
    // ... and y drops down to become x's left child.
    x.leftNext = y;

    // y now sits below x, so its height must be refreshed first.
    y.height = 1 + Math.max(getHeight(y.leftNext), getHeight(y.rightNext));
    x.height = 1 + Math.max(getHeight(x.leftNext), getHeight(x.rightNext));

    return x;
}

/**
 * Rotates the subtree rooted at {@code y} to the right and returns its new root, which is
 * {@code y}'s former left child. The heights of the two repositioned nodes are refreshed.
 */
private Node rightRotate(Node y){

    Node x = y.leftNext;
    // x's right subtree moves across to become y's left subtree ...
    y.leftNext = x.rightNext;
    // ... and y drops down to become x's right child.
    x.rightNext = y;

    // y now sits below x, so its height must be refreshed first.
    y.height = 1 + Math.max(getHeight(y.leftNext), getHeight(y.rightNext));
    x.height = 1 + Math.max(getHeight(x.leftNext), getHeight(x.rightNext));

    return x;
}

RL(先右后左)

插入的元素位于不平衡节点的右孩子的左子树中(RL情况),需要先右旋转、再左旋转来恢复平衡

  • 先对y的右孩子x进行右旋转

  • 再对y进行左旋转

/**
 * Recursively inserts {@code key}/{@code value} into the subtree rooted at {@code root} and
 * returns the subtree's new root. In this step only the RL case is repaired: a right rotation
 * of the right child followed by a left rotation of this node.
 */
private Node add(Node root, K key, V value){

    // Reached an empty slot: attach a new leaf here.
    if (root == null){
        size++;
        return new Node(key, value);
    }

    int cmp = root.key.compareTo(key);
    if (cmp > 0){
        root.leftNext = add(root.leftNext, key, value);
    }
    else if (cmp < 0){
        root.rightNext = add(root.rightNext, key, value);
    }
    else {
        // Duplicate key: overwrite the value; size is unchanged.
        root.value = value;
    }

    int leftHeight = getHeight(root.leftNext);
    int rightHeight = getHeight(root.rightNext);
    root.height = 1 + Math.max(leftHeight, rightHeight);

    // RL: right-heavy by more than 1 while the right child leans left.
    boolean rightLeft = getBalanceFactor(root) < -1 && getBalanceFactor(root.rightNext) > 0;
    if (!rightLeft){
        return root;
    }

    // Turn the RL shape into an RR shape, then fix it with a left rotation.
    root.rightNext = rightRotate(root.rightNext);
    return leftRotate(root);
}

/**
 * RL case (right-left): first rotate y's right child x to the right, then rotate y itself to
 * the left.
 *
 *       y                              y                                 z
 *     /  \                            /  \                              /   \
 *    T1   x    rightRotate(x)       T1   z      leftRotate(y)          y     x
 *        / \   -------------->          /  \    ------------->        / \   / \
 *       z  T4                          T2   x                       T1  T2 T3  T4
 *      / \                                 / \
 *    T2  T3                              T3  T4
 *
 * This method performs a left rotation around its argument (y in the second step above): the
 * argument's right child becomes the new subtree root and the argument becomes its left
 * child. Returns the new subtree root.
 */
private Node leftRotate(Node y){

    Node x = y.rightNext;
    // x's left subtree moves across to become y's right subtree ...
    y.rightNext = x.leftNext;
    // ... and y drops down to become x's left child.
    x.leftNext = y;

    // y now sits below x, so its height must be refreshed first.
    y.height = 1 + Math.max(getHeight(y.leftNext), getHeight(y.rightNext));
    x.height = 1 + Math.max(getHeight(x.leftNext), getHeight(x.rightNext));

    return x;
}

/**
 * Rotates the subtree rooted at {@code y} to the right and returns its new root, which is
 * {@code y}'s former left child. The heights of the two repositioned nodes are refreshed.
 */
private Node rightRotate(Node y){

    Node x = y.leftNext;
    // x's right subtree moves across to become y's left subtree ...
    y.leftNext = x.rightNext;
    // ... and y drops down to become x's right child.
    x.rightNext = y;

    // y now sits below x, so its height must be refreshed first.
    y.height = 1 + Math.max(getHeight(y.leftNext), getHeight(y.rightNext));
    x.height = 1 + Math.max(getHeight(x.leftNext), getHeight(x.rightNext));

    return x;
}

AVL树删除元素

/**
 * Removes {@code key} from the tree.
 *
 * @return the value that was mapped to {@code key}, or {@code null} if the key was absent
 */
public V remove(K key){

    Node target = getNode(root, key);
    if (target != null){
        root = remove(root, key);
        return target.value;
    }
    return null;
}

/**
 * Recursively removes {@code key} from the subtree rooted at {@code root}, then restores the
 * AVL balance of the replacement node before returning it.
 *
 * @return the new root of the subtree, or {@code null} if it became empty
 */
private Node remove(Node root, K key){

    if (root == null){
        return null;
    }

    Node res;
    int cmp = root.key.compareTo(key);

    if (cmp > 0){
        root.leftNext = remove(root.leftNext, key);
        res = root;
    }
    else if (cmp < 0){
        root.rightNext = remove(root.rightNext, key);
        res = root;
    }
    else if (root.leftNext == null){
        // No left child: the right subtree takes this node's place.
        res = root.rightNext;
        root.rightNext = null;
        size--;
    }
    else if (root.rightNext == null){
        // No right child: the left subtree takes this node's place.
        res = root.leftNext;
        root.leftNext = null;
        size--;
    }
    else {
        // Two children: the in-order successor replaces this node. It is unlinked through
        // remove() rather than removeMin() so that the right subtree is rebalanced too,
        // without duplicating the rotation code in removeMin().
        Node successor = Min(root.rightNext);
        successor.rightNext = remove(root.rightNext, successor.key);
        successor.leftNext = root.leftNext;
        root.leftNext = null;
        root.rightNext = null;
        res = successor;
    }

    // The subtree became empty: there is nothing left to rebalance.
    if (res == null){
        return null;
    }

    res.height = 1 + Math.max(getHeight(res.leftNext), getHeight(res.rightNext));

    // Each check re-reads res because an earlier rotation may already have replaced it.
    if (getBalanceFactor(res) > 1 && getBalanceFactor(res.leftNext) >= 0){      // LL
        res = rightRotate(res);
    }
    if (getBalanceFactor(res) < -1 && getBalanceFactor(res.rightNext) <= 0){    // RR
        res = leftRotate(res);
    }
    if (getBalanceFactor(res) > 1 && getBalanceFactor(res.leftNext) < 0){       // LR
        res.leftNext = leftRotate(res.leftNext);
        res = rightRotate(res);
    }
    if (getBalanceFactor(res) < -1 && getBalanceFactor(res.rightNext) > 0){     // RL
        res.rightNext = rightRotate(res.rightNext);
        res = leftRotate(res);
    }

    return res;
}

完整的AVL自平衡旋转

import java.util.ArrayList;

/**
 * Minimal key-value map abstraction implemented by the tree classes in this tutorial.
 *
 * @param <K> key type
 * @param <V> value type
 */
interface Map<K, V>{

    /** Adds a mapping from {@code key} to {@code value}. */
    void add(K key, V value);
    /** Removes the mapping for {@code key} and returns its value. */
    V remove(K key);
    /** Returns whether a mapping for {@code key} exists. */
    boolean contains(K key);
    /** Returns the value mapped to {@code key}. */
    V get(K key);
    /** Replaces the value mapped to {@code key}. */
    void set(K key, V value);
    /** Returns the number of mappings. */
    int getSize();
    /** Returns whether there are no mappings. */
    boolean isEmpty();
}

class AVLTree, V> implements Map{

    class Node{

        public K key;
        public V value;
        public Node leftNext;
        public Node rightNext;
        public int height;

        public Node(K key, V value){

            this.key = key;
            this.value = value;
            leftNext = null;
            rightNext = null;
            height = 1;
        }
    }

    private Node root;
    private int size;

    public AVLTree(){

        root = null;
        size = 0;
    }

    @Override
    public int getSize(){

        return size;
    }

    @Override
    public boolean isEmpty(){

        return size == 0;
    }

    private int getHeight(Node node){

        if (node == null){
            return 0;
        }
        else {
            return node.height;
        }
    }

    private int getBalanceFactor(Node node){

        if (node == null){
            return 0;
        }
        else {
            return getHeight(node.leftNext) - getHeight(node.rightNext);
        }
    }

    private boolean isBST(Node root){

        ArrayList ls = new ArrayList<>();
        inOrder(root, ls);

        for (int i = 0; i < ls.size() - 1; i++) {

            if (ls.get(i).compareTo(ls.get(i + 1)) > 0){
                return false;
            }
        }

        return true;
    }

    private void inOrder(Node root, ArrayList ls){

        if (root == null){
            return;
        }

        inOrder(root.leftNext, ls);
        ls.add(root.key);
        inOrder(root.rightNext, ls);
    }

    public boolean isBalanced(){

        return isBalanced(root);
    }

    private boolean isBalanced(Node root){

        if (root == null){
            return true;
        }

        int balanceFactor = getBalanceFactor(root);

        if (Math.abs(balanceFactor) > 1){
            return false;
        }
        else {
            return isBalanced(root.leftNext) && isBalanced(root.rightNext);
        }
    }

    private Node getNode(Node root, K key){

        if (root == null){
            return null;
        }

        if (root.key.compareTo(key) == 0){
            return root;
        }
        else if (root.key.compareTo(key) > 0){
            return getNode(root.leftNext, key);
        }
        else {
            return getNode(root.rightNext, key);
        }
    }

    @Override
    public boolean contains(K key){

        return getNode(root, key) != null;
    }

    @Override
    public V get(K key){

        Node res = getNode(root, key);

        return res == null ? null : res.value;
    }

    @Override
    public void add(K key, V value){

        root = add(root, key, value);
    }

    private Node add(Node root, K key, V value){

        if (root == null){

            size++;
            return new Node(key, value);
        }

        if (root.key.compareTo(key) > 0){
            root.leftNext = add(root.leftNext, key, value);
        }
        else if (root.key.compareTo(key) < 0){
            root.rightNext = add(root.rightNext, key, value);
        }
        else {
            root.value = value;
        }

        root.height = Math.max(getHeight(root.leftNext), getHeight(root.rightNext)) + 1;

        int balanceFactor = getBalanceFactor(root);

        if (balanceFactor > 1 && getBalanceFactor(root.leftNext) >= 0){
            return rightRotate(root);
        }

        if (balanceFactor < -1 && getBalanceFactor(root.rightNext) <= 0){
            return leftRotate(root);
        }

        if (balanceFactor > 1 && getBalanceFactor(root.leftNext) < 0){

            root.leftNext = leftRotate(root.leftNext);
            return rightRotate(root);
        }

        if (balanceFactor < -1 && getBalanceFactor(root.rightNext) > 0){

            root.rightNext = rightRotate(root.rightNext);
            return leftRotate(root);
        }

        return root;
    }

    private Node leftRotate(Node y){

        Node x = y.rightNext;
        Node T3 = x.leftNext;
        x.leftNext = y;
        y.rightNext = T3;

        y.height = Math.max(getHeight(y.leftNext), getHeight(y.rightNext)) + 1;
        x.height = Math.max(getHeight(x.leftNext), getHeight(x.rightNext)) + 1;

        return x;
    }

    private Node rightRotate(Node y){

        Node x = y.leftNext;
        Node T3 = x.rightNext;
        x.rightNext = y;
        y.leftNext = T3;

        y.height = Math.max(getHeight(y.leftNext), getHeight(y.rightNext)) + 1;
        x.height = Math.max(getHeight(x.leftNext), getHeight(x.rightNext)) + 1;

        return x;
    }

    @Override
    public void set(K key, V value) {

        Node res = getNode(root, key);

        if (res == null){
            throw new IllegalArgumentException("键值不存在");
        }
        else {
            res.value = value;
        }
    }

    private Node Min(Node root){

        if (root.leftNext == null){
            return root;
        }

        return Min(root.leftNext);
    }

    @Override
    public V remove(K key){

        Node res = getNode(root, key);

        if (res == null){
            return null;
        }

        root = remove(root, key);

        return res.value;
    }

    private Node remove(Node root, K key){

        if (root == null){
            return null;
        }

        Node res;

        if (root.key.compareTo(key) > 0){

            root.leftNext = remove(root.leftNext, key);
            res = root;
        }
        else if (root.key.compareTo(key) < 0){

            root.rightNext = remove(root.rightNext, key);
            res = root;
        }
        else {

            if (root.leftNext == null) {

                Node newRoot = root.rightNext;
                root.rightNext = null;
                size--;

                res = newRoot;
            }
            else if (root.rightNext == null) {

                Node newRoot = root.leftNext;
                root.leftNext = null;
                size--;

                res = newRoot;
            }
            else {

                Node min = Min(root.rightNext);
                min.rightNext = remove(root.rightNext, min.key);
                min.leftNext = root.leftNext;
                root.leftNext = null;
                root.rightNext = null;

                res = min;
            }
        }
        
        if (res == null){
            return null;
        }

        res.height = Math.max(getHeight(res.leftNext), getHeight(res.rightNext)) + 1;

        if (getBalanceFactor(res) > 1 && getBalanceFactor(res.leftNext) >= 0){
            res = rightRotate(res);
        }

        if (getBalanceFactor(res) < -1 && getBalanceFactor(res.rightNext) <= 0){
            res = leftRotate(res);
        }

        if (getBalanceFactor(res) > 1 && getBalanceFactor(res.leftNext) < 0){

            res.leftNext = leftRotate(res.leftNext);
            res = rightRotate(res);
        }

        if (getBalanceFactor(res) < -1 && getBalanceFactor(res.rightNext) > 0){

            res.rightNext = rightRotate(res.rightNext);
            res = leftRotate(res);
        }

        return res;
    }
}

AVL树实现映射(Map)

class AVLMap, V> implements Map {

    private AVLTree avlMap;

    public AVLMap(){

        avlMap = new AVLTree<>();
    }

    @Override
    public int getSize(){

        return avlMap.getSize();
    }

    @Override
    public boolean isEmpty(){

        return avlMap.isEmpty();
    }

    @Override
    public boolean contains(K key){

        return avlMap.contains(key);
    }

    @Override
    public V get(K key){

        return avlMap.get(key);
    }

    @Override
    public void set(K key, V value){

        avlMap.set(key, value);
    }

    @Override
    public void add(K key, V value){

        avlMap.add(key, value);
    }

    @Override
    public V remove(K key){

        return avlMap.remove(key);
    }
}

......

AVL树实现集合(Set)

/**
 * Minimal set abstraction implemented by the tree-based set in this tutorial.
 *
 * @param <E> element type
 */
interface Set<E>{

    /** Adds {@code e} to the set. */
    void add(E e);
    /** Removes {@code e} from the set. */
    void remove(E e);
    /** Returns whether {@code e} is in the set. */
    boolean contains(E e);
    /** Returns the number of elements. */
    int getSize();
    /** Returns whether the set has no elements. */
    boolean isEmpty();
}

/**
 * Set implementation backed by an {@link AVLTree}. Only the tree's keys are used; each
 * element is stored with a {@code null} value whose declared type is {@code Object}.
 *
 * @param <E> element type, ordered by {@link Comparable#compareTo}
 */
class AVLSet<E extends Comparable<E>> implements Set<E> {

    /** Balanced tree whose keys are the set's elements. */
    private final AVLTree<E, Object> avlSet;

    public AVLSet(){

        avlSet = new AVLTree<>();
    }

    @Override
    public int getSize(){

        return avlSet.getSize();
    }

    @Override
    public boolean isEmpty(){

        return avlSet.isEmpty();
    }

    @Override
    public boolean contains(E key){

        return avlSet.contains(key);
    }

    @Override
    public void add(E key){

        avlSet.add(key, null);
    }

    @Override
    public void remove(E key){

        avlSet.remove(key);
    }
}

......