# TF2实现语义分割网络DeepLabV3+ (TF2 implementation of the DeepLabV3+ semantic segmentation network)



"""
Created on 2020/11/29 19:59.

@Author: yubaby@anne
@Email: yhaif@foxmail.com
"""


from tensorflow.keras.layers import Input, Conv2D, BatchNormalization, Activation
from tensorflow.keras.layers import Add, Concatenate, UpSampling2D, SeparableConv2D
from tensorflow.keras.layers import GlobalAveragePooling2D, Reshape
from tensorflow.keras import Model


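# Encoder backbone (modified Xception). NOTE: originally labelled "串行空洞卷积" (cascaded atrous
# convolution), but this backbone contains no dilated convolutions.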
def _xception_block(x, filters, strides=(1, 1), conv_shortcut=True, pre_activation=True):
    """Build one Xception residual unit made of separable convolutions.

    Each SeparableConv2D is followed by BatchNormalization; a ReLU precedes
    every separable conv except possibly the first (see ``pre_activation``).

    Args:
        x: input feature map.
        filters: output-channel count of each SeparableConv2D, in order
            (three per block in this file).
        strides: stride of the LAST separable conv and of the 1x1 shortcut
            conv; ``(2, 2)`` halves the spatial size.
        conv_shortcut: if True the shortcut is a 1x1 Conv2D + BN with
            ``filters[-1]`` channels and stride ``strides``; otherwise it is
            the identity (requires ``strides == (1, 1)`` and ``filters[-1]``
            equal to the input channel count).
        pre_activation: if True a ReLU is applied to ``x`` before the first
            separable conv.

    Returns:
        ``(output, last_relu)``: the block output after the residual Add, and
        the tensor fed into the last separable conv (a ReLU output whenever
        ``len(filters) > 1``).
    """
    # Layers are created in exactly the same order as the original inline
    # implementation, so Keras' auto-generated layer names (derived from
    # creation order) stay unchanged.
    if conv_shortcut:
        shortcut = Conv2D(filters[-1], kernel_size=(1, 1), strides=strides, padding='same')(x)
        shortcut = BatchNormalization()(shortcut)
    else:
        shortcut = x

    y = Activation('relu')(x) if pre_activation else x
    last_relu = y
    for i, num in enumerate(filters):
        if i > 0:
            y = Activation('relu')(y)
            last_relu = y
        # Only the last separable conv downsamples.
        conv_strides = strides if i == len(filters) - 1 else (1, 1)
        y = SeparableConv2D(num, kernel_size=(3, 3), strides=conv_strides, padding='same')(y)
        y = BatchNormalization()(y)
    return Add()([y, shortcut]), last_relu


def Xception(inputs, middle_blocks=16):
    """Modified Xception backbone used as the DeepLabV3+ encoder.

    Downsampling is done with strided convolutions only: the stem and each of
    the three entry-flow blocks use stride 2, and the exit flow keeps stride 1,
    so the overall output stride is 16. No dilated convolutions are used.

    Args:
        inputs: input image tensor (channels-last, as built by ``build_model``).
        middle_blocks: number of identity-shortcut residual blocks in the
            middle flow (16 by default, as in the original code).

    Returns:
        ``(x, x_low_level_feature)``:
            x: 2048-channel feature map at 1/16 of the input resolution
                (exactly 1/16 when the input size is a multiple of 16).
            x_low_level_feature: 256-channel ReLU activation at 1/4 of the
                input resolution, taken inside the second entry-flow block;
                the decoder uses it for spatial detail.
    """
    # NOTE(review): convolutions keep Keras' default use_bias=True even though
    # each is followed by BatchNormalization; left as-is so existing weights
    # still load.

    # ---------------------------------------------------------------#
    # Entry flow: a 2-conv stem followed by 3 strided residual blocks
    # ---------------------------------------------------------------#
    x = inputs
    for num, stride in ((32, (2, 2)), (64, (1, 1))):
        x = Conv2D(num, kernel_size=(3, 3), strides=stride, padding='same')(x)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)

    # The stem already ends in a ReLU, so block 1 skips the pre-activation.
    x = _xception_block(x, (128, 128, 128), strides=(2, 2), pre_activation=False)[0]
    # Low-level feature: the ReLU feeding block 2's strided separable conv
    # (1/4 resolution, 256 channels).
    x, x_low_level_feature = _xception_block(x, (256, 256, 256), strides=(2, 2))
    x = _xception_block(x, (728, 728, 728), strides=(2, 2))[0]

    # ---------------------------------------------------------------#
    # Middle flow: identity-shortcut blocks that deepen the network
    # ---------------------------------------------------------------#
    for _ in range(middle_blocks):
        x = _xception_block(x, (728, 728, 728), conv_shortcut=False)[0]

    # ---------------------------------------------------------------#
    # Exit flow: one stride-1 residual block, then 3 separable convs
    # ---------------------------------------------------------------#
    x = _xception_block(x, (728, 1024, 1024))[0]

    for num in (1536, 1536, 2048):
        x = SeparableConv2D(num, kernel_size=(3, 3), strides=(1, 1), padding='same')(x)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)

    return x, x_low_level_feature


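# Encoder head: ASPP (Atrous Spatial Pyramid Pooling) - parallel atrous (dilated) convolutions
# plus an image-level pooling branch (并行空洞卷积).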
def ASPP(x, filter_num, old_filter_num, rates=(6, 12, 18)):
    """Atrous Spatial Pyramid Pooling head.

    Five kinds of parallel branches are computed from ``x`` and concatenated:
    an image-level branch (global average pool -> 1x1 conv, broadcast back to
    the feature-map size), a 1x1 conv branch, and one dilated 3x3 conv branch
    per entry of ``rates``. Every branch is Conv -> BN -> ReLU, and the
    concatenation is projected back to ``filter_num`` channels by a 1x1 conv.

    Args:
        x: channels-last feature map whose spatial size is statically known.
        filter_num: output channels of every branch and of the final 1x1
            projection.
        old_filter_num: channel count of ``x``; used to reshape the pooled
            vector to ``(1, 1, old_filter_num)``.
        rates: dilation rates of the parallel 3x3 branches.

    Returns:
        Tensor with the same spatial size as ``x`` and ``filter_num`` channels.

    Raises:
        ValueError: if the height or width of ``x`` is not statically known.
    """
    height, width = x.shape[1], x.shape[2]
    if height is None or width is None:
        raise ValueError(
            'ASPP needs a statically known spatial size, got shape %s' % (tuple(x.shape),))

    # Image-level branch. It is broadcast back to the actual size of ``x``;
    # the previous hard-coded (16, 16) only matched a 16x16 feature map and
    # made Concatenate fail for any other input size.
    x_pool = GlobalAveragePooling2D()(x)
    x_pool = Reshape((1, 1, old_filter_num))(x_pool)
    x_pool = Conv2D(filter_num, kernel_size=(1, 1), strides=(1, 1), padding='same')(x_pool)
    x_pool = BatchNormalization()(x_pool)
    x_pool = Activation('relu')(x_pool)
    x_pool = UpSampling2D(size=(height, width))(x_pool)
    branches = [x_pool]

    # 1x1 branch.
    x_1 = Conv2D(filter_num, kernel_size=(1, 1), strides=(1, 1), padding='same')(x)
    x_1 = BatchNormalization()(x_1)
    branches.append(Activation('relu')(x_1))

    # Dilated 3x3 branches; 'same' padding keeps the spatial size.
    for rate in rates:
        x_r = Conv2D(filter_num, kernel_size=(3, 3), strides=(1, 1), padding='same',
                     dilation_rate=rate)(x)
        x_r = BatchNormalization()(x_r)
        branches.append(Activation('relu')(x_r))

    x = Concatenate()(branches)
    x = Conv2D(filter_num, kernel_size=(1, 1), strides=(1, 1), padding='same')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    return x


def build_model(tif_size, bands, class_num):
    """Build the DeepLabV3+ semantic segmentation model.

    Encoder: ``Xception`` backbone followed by ``ASPP``. Decoder: the ASPP
    output is upsampled x4, concatenated with a 48-channel projection of the
    low-level backbone feature, refined by two 3x3 convs, upsampled x4 again,
    and classified per pixel with a softmax 1x1 conv.

    Args:
        tif_size: height and width of the square input tiles. Must be a
            positive multiple of 16: the encoder downsamples by 16 and the
            decoder upsamples by 4 twice, so other sizes make the low-/high-
            level feature maps (or the output and input sizes) disagree.
            ``None`` is passed through to ``Input`` unchecked.
        bands: number of input channels.
        class_num: number of output classes (softmax channels).

    Returns:
        An uncompiled Keras ``Model`` from inputs of shape
        ``(batch, tif_size, tif_size, bands)`` to per-pixel class
        probabilities with ``class_num`` channels.

    Raises:
        ValueError: if ``tif_size`` is not a positive multiple of 16, or if
            ``bands`` or ``class_num`` is smaller than 1.
    """
    from pathlib import Path

    print('===== %s =====' % Path(__file__).name)
    # Print the function name directly instead of via the private sys._getframe().
    print('===== %s =====' % build_model.__name__)

    if tif_size is not None and (tif_size <= 0 or tif_size % 16 != 0):
        raise ValueError('tif_size must be a positive multiple of 16, got %r' % (tif_size,))
    if bands < 1 or class_num < 1:
        raise ValueError('bands and class_num must be >= 1, got bands=%r, class_num=%r'
                         % (bands, class_num))

    inputs = Input(shape=(tif_size, tif_size, bands))

    x, x_low_level_feature = Xception(inputs)  # low-level features supply spatial detail
    x_high_level_feature = ASPP(x, 256, 2048)  # high-level features supply semantics

    # Decoder: bring the 1/16 ASPP output up to the 1/4 low-level resolution.
    # NOTE(review): reference DeepLabV3+ implementations usually follow this
    # 1x1 projection with BN + ReLU and use bilinear upsampling; here there is
    # no BN/ReLU and UpSampling2D uses its default interpolation. Left
    # unchanged to keep the architecture/weights stable - confirm intent.
    x_low_level_feature = Conv2D(48, kernel_size=(1, 1), strides=(1, 1), padding='same')(x_low_level_feature)
    x_high_level_feature = UpSampling2D(size=(4, 4))(x_high_level_feature)

    x = Concatenate()([x_low_level_feature, x_high_level_feature])
    x = Conv2D(256, kernel_size=(3, 3), strides=(1, 1), padding='same')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Conv2D(256, kernel_size=(3, 3), strides=(1, 1), padding='same')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    # Back to full input resolution, then per-pixel class probabilities.
    x = UpSampling2D(size=(4, 4))(x)
    x = Conv2D(class_num, kernel_size=(1, 1), strides=(1, 1), padding='same', activation='softmax')(x)

    mymodel = Model(inputs, x)
    return mymodel