# TF2实现语义分割网络Res-SegNet (TF2 implementation of the Res-SegNet semantic segmentation network)



"""
Created on 2020/11/29 19:45.

@Author: yubaby@anne
@Email: yhaif@foxmail.com
"""


from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D,  BatchNormalization, Activation
from tensorflow.keras.layers import Add, UpSampling2D
from tensorflow.keras import Model


def identity_block(input_x, filter_list):  # identity ("solid-line") residual block
    """Bottleneck residual block whose shortcut is the unchanged input.

    Main path: 1x1 conv -> BN -> ReLU -> 3x3 conv -> BN -> ReLU -> 1x1 conv -> BN,
    all with stride 1 and 'same' padding, so the spatial size is preserved.
    The main path is added to ``input_x`` and passed through a final ReLU.

    Args:
        input_x: Input Keras tensor. Because it is added directly to the main
            path, it must already have ``filters3`` channels.
        filter_list: ``[filters1, filters2, filters3]`` channel counts of the
            three convolutions.

    Returns:
        Output tensor with the same shape as ``input_x``.
    """
    first, middle, last = filter_list
    conv_specs = ((first, (1, 1)), (middle, (3, 3)), (last, (1, 1)))
    final_idx = len(conv_specs) - 1

    branch = input_x
    for idx, (n_filters, ksize) in enumerate(conv_specs):
        branch = Conv2D(filters=n_filters, kernel_size=ksize, strides=1, padding='same')(branch)
        branch = BatchNormalization()(branch)
        # The last conv's ReLU is applied only after the residual sum.
        if idx != final_idx:
            branch = Activation('relu')(branch)

    summed = Add()([branch, input_x])
    return Activation('relu')(summed)


def conv_block(input_x, filter_list, strides=2):  # projection ("dashed-line") residual block
    """Bottleneck residual block with a 1x1 projection shortcut.

    Main path: 1x1 conv -> BN -> ReLU -> 3x3 conv (stride ``strides``) -> BN
    -> ReLU -> 1x1 conv -> BN. Shortcut: 1x1 conv (stride ``strides``) -> BN.
    The shortcut projects ``input_x`` to ``filters3`` channels and to the main
    path's spatial size. The two branches are summed and passed through a
    final ReLU.

    Args:
        input_x: Input Keras tensor.
        filter_list: ``[filters1, filters2, filters3]`` channel counts of the
            three main-path convolutions. ``filters3`` is also the number of
            output channels.
        strides: Stride of the 3x3 convolution and of the shortcut projection.
            The default of 2 halves the spatial size and preserves the original
            behaviour. Pass 1 to change the channel count without
            downsampling, as in the first stage of canonical ResNet-50.

    Returns:
        Output tensor with ``filters3`` channels.
    """
    filters1, filters2, filters3 = filter_list

    x = Conv2D(filters=filters1, kernel_size=(1, 1), strides=1, padding='same')(input_x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    # Downsampling (when strides > 1) happens in the 3x3 conv.
    x = Conv2D(filters=filters2, kernel_size=(3, 3), strides=strides, padding='same')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    x = Conv2D(filters=filters3, kernel_size=(1, 1), strides=1, padding='same')(x)
    x = BatchNormalization()(x)

    # Shortcut: must use the same stride as the main path so the shapes match in Add.
    shortcut = Conv2D(filters=filters3, kernel_size=(1, 1), strides=strides, padding='same')(input_x)
    shortcut = BatchNormalization()(shortcut)

    x = Add()([x, shortcut])
    x = Activation('relu')(x)

    return x


def get_resnet50_encoder(inputs, block_list=(3, 4, 6, 3)):
    """Build a ResNet-50-style bottleneck encoder and return multi-scale features.

    Args:
        inputs: Keras tensor fed to the first 7x7 convolution.
        block_list: Number of bottleneck blocks in each of the four residual
            stages (conv2_x .. conv5_x). The default ``(3, 4, 6, 3)`` is the
            ResNet-50 layout.

    Returns:
        ``[f1, f2, f3, f4, f5]``, the outputs of conv1 and of each residual
        stage. Every stride-2 operation halves the spatial size, up to rounding
        from 'same' / 'valid' padding. Relative to ``inputs`` the features are
        therefore about 2x, 8x, 16x, 32x and 64x smaller.

        ``conv_block`` always strides 2, including in conv2_x, where canonical
        ResNet-50 keeps the resolution. That is why f2 is 8x smaller, not 4x.
        f1 is the raw conv1 output, taken *before* BatchNorm/ReLU.

    Raises:
        ValueError: if ``block_list`` does not have exactly four entries.
    """
    if len(block_list) != 4:
        raise ValueError('block_list must have exactly 4 entries, got %r' % (block_list,))

    stage_filters = (
        [64, 64, 256],     # conv2_x
        [128, 128, 512],   # conv3_x
        [256, 256, 1024],  # conv4_x
        [512, 512, 2048],  # conv5_x
    )

    # conv1: 7x7 / stride-2 convolution. f1 is tapped before BN/ReLU.
    x = Conv2D(filters=64, kernel_size=(7, 7), strides=2, padding='same')(inputs)
    features = [x]
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    # 3x3 / stride-2 max pooling (Keras default 'valid' padding).
    x = MaxPooling2D(pool_size=(3, 3), strides=2)(x)

    # conv2_x .. conv5_x: one projection block (stride 2) followed by identity
    # blocks. The decoder in build_model consumes f4 and upsamples it 32x, so it
    # relies on this downsampling schedule.
    for num_blocks, filters in zip(block_list, stage_filters):
        x = conv_block(input_x=x, filter_list=filters)
        for _ in range(num_blocks - 1):
            x = identity_block(input_x=x, filter_list=filters)
        features.append(x)

    return features


def get_segnet_decoder(feature):
    """Upsample ``feature`` 32x through five SegNet-style decoder stages.

    Each stage is a 2x ``UpSampling2D`` followed by a 3x3 'same' convolution
    with ReLU and then BatchNormalization. The stages use 512, 512, 256, 128
    and 64 filters in that order.

    Args:
        feature: Encoder feature map (Keras tensor) to decode.

    Returns:
        Tensor with 64 channels whose height and width are 32x those of ``feature``.
    """
    decoded = feature
    for n_filters in (512, 512, 256, 128, 64):
        decoded = UpSampling2D(size=(2, 2))(decoded)
        decoded = Conv2D(n_filters, (3, 3), strides=1, padding='same', activation='relu')(decoded)
        decoded = BatchNormalization()(decoded)
    return decoded


def build_model(tif_size, bands, class_num):
    """Assemble the Res-SegNet model: a ResNet-50 encoder plus a SegNet decoder.

    Prints the module file name and this function's name as a banner, then
    builds the model.

    Args:
        tif_size: Height and width of the square input tiles. It must be a
            multiple of 32, or ``None`` for a variable-size input. The encoder
            output used here (levels[3]) is 32x smaller than the input and the
            decoder upsamples it 32x. Any other size would produce a softmax
            map whose size differs from the input.
        bands: Number of input channels.
        class_num: Number of output classes, i.e. the channels of the softmax map.

    Returns:
        An uncompiled ``tensorflow.keras.Model`` that maps
        ``(tif_size, tif_size, bands)`` inputs to
        ``(tif_size, tif_size, class_num)`` per-pixel softmax outputs.

    Raises:
        ValueError: if ``tif_size`` is not ``None`` and not a multiple of 32.
    """
    from pathlib import Path

    if tif_size is not None and tif_size % 32 != 0:
        raise ValueError(
            'tif_size must be a multiple of 32 so the output matches the input size, got %r'
            % (tif_size,))

    print('===== %s =====' % Path(__file__).name)
    # Public equivalent of sys._getframe().f_code.co_name (a CPython detail).
    print('===== %s =====' % build_model.__name__)

    # Input
    inputs = Input(shape=(tif_size, tif_size, bands))
    # Encoder
    levels = get_resnet50_encoder(inputs)
    # Decoder: starts from f4 (levels[3]), the 32x-downsampled feature map.
    x = get_segnet_decoder(feature=levels[3])
    # Output: per-pixel class probabilities.
    x = Conv2D(class_num, (1, 1), strides=1, padding='same', activation='softmax')(x)

    model = Model(inputs, x)
    return model