Implementing the Lightweight Semantic Segmentation Network BiSeNet in TF2



"""
Created on 2021/1/4 20:25.
@Author: anne
"""
# https://blog.csdn.net/Hanghang_/article/details/108592828  detailed walkthrough of the BiSeNet network structure
# https://blog.csdn.net/TTLoveYuYu/article/details/114372733  detailed walkthrough of the related paper


from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, BatchNormalization
from tensorflow.keras.layers import UpSampling2D, ZeroPadding2D, Activation, Reshape
from tensorflow.keras.layers import SeparableConv2D, add, GlobalAveragePooling2D, multiply, concatenate
from tensorflow.keras import Model


# Attention Refinement Module (ARM): channel attention built from global pooling, a 1x1 conv, BN and a sigmoid
def ARM(x, old_filter_num):
    x_branch = GlobalAveragePooling2D()(x)
    x_branch = Reshape((1, 1, old_filter_num))(x_branch)
    x_branch = Conv2D(old_filter_num, kernel_size=(1, 1), strides=(1, 1), padding='same')(x_branch)
    x_branch = BatchNormalization()(x_branch)
    x_branch = Activation('sigmoid')(x_branch)
    x = multiply([x_branch, x])
    return x
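
# Shape sketch for ARM (illustrative numbers): for an 8x8x2048 context feature, global
# pooling yields a 1x1x2048 vector, the 1x1 conv + sigmoid turn it into per-channel
# weights, and the multiply broadcasts them back, so the output keeps the input shape:
#   feat = Input(shape=(8, 8, 2048))
#   out = ARM(feat, 2048)   # out shape: (None, 8, 8, 2048)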


# Feature Fusion Module (FFM): concatenates the two paths, projects to num_classes channels, then applies SE-style channel re-weighting with a residual add
def FFM(spatial_path, context_path, num_classes):
    x = concatenate([spatial_path, context_path], axis=-1)
    x = Conv2D(num_classes, kernel_size=(3, 3), strides=(1, 1), padding='same')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x_branch = GlobalAveragePooling2D()(x)
    x_branch = Reshape((1, 1, num_classes))(x_branch)
    x_branch = Conv2D(num_classes, kernel_size=(1, 1), strides=(1, 1), padding='same')(x_branch)
    x_branch = Activation('relu')(x_branch)
    x_branch = Conv2D(num_classes, kernel_size=(1, 1), strides=(1, 1), padding='same')(x_branch)
    x_branch = Activation('sigmoid')(x_branch)
    x_mul = multiply([x_branch, x])
    x = add([x_mul, x])
    return x
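
# Shape sketch for FFM (illustrative numbers): with a 32x32x256 Spatial Path feature and
# a 32x32x3072 Context Path feature, concatenation gives 32x32x3328, the 3x3 conv projects
# it to num_classes channels, and the SE-style branch re-weights those channels before the
# residual add, so the output is 32x32xnum_classes:
#   sp = Input(shape=(32, 32, 256))
#   cp = Input(shape=(32, 32, 3072))
#   out = FFM(sp, cp, 2)   # out shape: (None, 32, 32, 2)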


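# Spatial Path: three stride-2 conv blocks that preserve low-level spatial detail.
# Shape sketch (illustrative, assuming a 256x256 input): 256 -> 128 -> 64 -> 32,
# i.e. a 1/8-resolution feature map with 256 channels.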
def spatial_path(x):
    x = Conv2D(64, kernel_size=(3, 3), strides=(2, 2), padding='same')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Conv2D(128, kernel_size=(3, 3), strides=(2, 2), padding='same')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Conv2D(256, kernel_size=(3, 3), strides=(2, 2), padding='same')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    return x


# Encoder building blocks: modified Xception entry / middle flows
# ======================================================================
# Entry-flow block 1: two separable convs + max pooling, with a strided 1x1 shortcut (halves the resolution)
def entry_flow_1(x, filter_num):
    x_sep = SeparableConv2D(filter_num, kernel_size=(3, 3), strides=(1, 1), padding='same')(x)
    x_sep = BatchNormalization()(x_sep)
    x_sep = Activation('relu')(x_sep)
    x_sep = SeparableConv2D(filter_num, kernel_size=(3, 3), strides=(1, 1), padding='same')(x_sep)
    x_sep = BatchNormalization()(x_sep)
    x_sep = MaxPooling2D(pool_size=(2, 2))(x_sep)
    x_shortcut = Conv2D(filter_num, kernel_size=(1, 1), strides=(2, 2), padding='same')(x)
    x_shortcut = BatchNormalization()(x_shortcut)
    x = add([x_sep, x_shortcut])
    return x


# Entry-flow blocks 2 and 3: same structure as block 1, preceded by a ReLU (halves the resolution)
def entry_flow_2and3(x, filter_num):
    x_sep = Activation('relu')(x)
    x_sep = SeparableConv2D(filter_num, kernel_size=(3, 3), strides=(1, 1), padding='same')(x_sep)
    x_sep = BatchNormalization()(x_sep)
    x_sep = Activation('relu')(x_sep)
    x_sep = SeparableConv2D(filter_num, kernel_size=(3, 3), strides=(1, 1), padding='same')(x_sep)
    x_sep = BatchNormalization()(x_sep)
    x_sep = MaxPooling2D(pool_size=(2, 2))(x_sep)
    x_shortcut = Conv2D(filter_num, kernel_size=(1, 1), strides=(2, 2), padding='same')(x)
    x_shortcut = BatchNormalization()(x_shortcut)
    x = add([x_sep, x_shortcut])
    return x


# Middle-flow block: three separable convs with an identity shortcut (resolution unchanged)
def middle_flow(x, filter_num):
    x_sep = Activation('relu')(x)
    x_sep = SeparableConv2D(filter_num, kernel_size=(3, 3), strides=(1, 1), padding='same')(x_sep)
    x_sep = BatchNormalization()(x_sep)
    x_sep = Activation('relu')(x_sep)
    x_sep = SeparableConv2D(filter_num, kernel_size=(3, 3), strides=(1, 1), padding='same')(x_sep)
    x_sep = BatchNormalization()(x_sep)
    x_sep = Activation('relu')(x_sep)
    x_sep = SeparableConv2D(filter_num, kernel_size=(3, 3), strides=(1, 1), padding='same')(x_sep)
    x_sep = BatchNormalization()(x_sep)
    x = add([x_sep, x])
    return x
# ======================================================================


def build_model(tif_size, bands, class_num):
    from pathlib import Path
    import sys
    print('===== %s =====' % Path(__file__).name)
    print('===== %s =====' % sys._getframe().f_code.co_name)

    # 1. Input
    inputs = Input(shape=(tif_size, tif_size, bands))

    # 2. Encoder: a modified Xception, similar in spirit to the DeepLabv3+ backbone
    # ======================================================================
    x = ZeroPadding2D(padding=(1, 1))(inputs)
    x = Conv2D(32, kernel_size=(3, 3), strides=(2, 2), padding='valid')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = ZeroPadding2D(padding=(1, 1))(x)
    x = Conv2D(64, kernel_size=(3, 3), strides=(1, 1), padding='valid')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    x = entry_flow_1(x, 128)
    x = entry_flow_2and3(x, 256)
    x = entry_flow_2and3(x, 728)
    # 8 repeated middle-flow blocks (resolution stays at 1/16)
    for _ in range(8):
        x = middle_flow(x, 728)
    x_exit_first = x

    x = Activation('relu')(x_exit_first)
    x = SeparableConv2D(728, kernel_size=(3, 3), strides=(1, 1), padding='same')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = SeparableConv2D(1024, kernel_size=(3, 3), strides=(1, 1), padding='same')(x)
    x_down16 = BatchNormalization()(x)
    x = MaxPooling2D(pool_size=(2, 2))(x_down16)
    x_shortcut = Conv2D(1024, kernel_size=(1, 1), strides=(2, 2), padding='same')(x_exit_first)
    x_shortcut = BatchNormalization()(x_shortcut)
    x = add([x, x_shortcut])
    x = SeparableConv2D(1536, kernel_size=(3, 3), strides=(1, 1), padding='same')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = SeparableConv2D(2048, kernel_size=(3, 3), strides=(1, 1), padding='same')(x)
    x = BatchNormalization()(x)
    # ======================================================================
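    # Resolution recap: stem 1/2, entry flows 1/4 -> 1/8 -> 1/16 (the middle flows and the
    # 1024-channel x_down16 tap stay at 1/16), exit-flow max pooling 1/32 (2048 channels
    # after the final separable conv); the 1/16 and 1/32 features feed the Context Path below.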

    # 3. Decoder: BiSeNet head
    # 3-1 Spatial Path  --> low-level features
    result_sp = spatial_path(inputs)
    # 3-2 Context Path  --> high-level features
    x_down32 = Activation('relu')(x)  # the Context Path reuses the modified Xception above as its backbone
    x_global = GlobalAveragePooling2D()(x_down32)  # global context vector
    x_global = Reshape((1, 1, 2048))(x_global)
    x_down32 = ARM(x_down32, 2048)  # refine the 1/32 feature
    x_down32 = multiply([x_down32, x_global])  # inject global context into the 1/32 feature
    x_down16 = ARM(x_down16, 1024)  # refine the 1/16 feature
    x_down32 = UpSampling2D(size=(4, 4))(x_down32)  # 1/32 -> 1/8
    x_down16 = UpSampling2D(size=(2, 2))(x_down16)  # 1/16 -> 1/8
    result_cp = concatenate([x_down32, x_down16], axis=-1)  # Context Path output at 1/8
    # 3-3 Feature Fusion Module, then 8x upsampling back to the input resolution
    x = FFM(result_sp, result_cp, class_num)
    x = UpSampling2D(size=(8, 8))(x)

    # 4. Output: per-pixel softmax over class_num classes
    x = Conv2D(class_num, (1, 1), strides=(1, 1), padding='same', activation='softmax')(x)

    mymodel = Model(inputs, x)
    return mymodel


# model = build_model(256, 3, 2)
# model.summary()
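

# Minimal usage sketch expanding the commented-out call above (the 256 / 3 / 2 values are
# illustrative assumptions): build the model, print its summary, and push a dummy batch
# through it to confirm the output shape is (batch, tif_size, tif_size, class_num).
if __name__ == '__main__':
    import numpy as np

    model = build_model(256, 3, 2)   # 256x256 tiles, 3 bands, 2 classes
    model.summary()
    dummy = np.zeros((1, 256, 256, 3), dtype='float32')
    pred = model.predict(dummy)
    print(pred.shape)   # expected: (1, 256, 256, 2)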