TF2 Implementation of the PSPNet Semantic Segmentation Network



"""
Created on 2020/11/29 19:51.

@Author: yubaby@anne
@Email: yhaif@foxmail.com
"""


import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Dropout, BatchNormalization, Activation
from tensorflow.keras.layers import Add, ZeroPadding2D, AveragePooling2D, Lambda, Concatenate
from tensorflow.keras import Model
import tensorflow.keras.backend as K


tf.compat.v1.disable_eager_execution()
IMAGE_ORDERING = 'channels_last'
if IMAGE_ORDERING == 'channels_first':  # 'NCHW'
    MERGE_AXIS = 1
elif IMAGE_ORDERING == 'channels_last':  # 'NHWC'
    MERGE_AXIS = -1
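
# With IMAGE_ORDERING = 'channels_last' (the setting used here), tensors are laid out as
# (batch, height, width, channels), so MERGE_AXIS = -1 selects the channel axis that the
# PSP head concatenates along.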


def identity_block(input_x, filter_list, dilation_rate=1):  # identity residual block ("solid-line" shortcut: direct skip connection)
    filters1, filters2, filters3 = filter_list

    x = Conv2D(filters=filters1, kernel_size=(1, 1), strides=1, padding='same')(input_x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    x = Conv2D(filters=filters2, kernel_size=(3, 3), strides=1, padding='same', dilation_rate=dilation_rate)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    x = Conv2D(filters=filters3, kernel_size=(1, 1), strides=1, padding='same')(x)
    x = BatchNormalization()(x)

    x = Add()([x, input_x])
    x = Activation('relu')(x)

    return x


def conv_block(input_x, filter_list, strides=2, dilation_rate=1):  # projection residual block ("dashed-line" shortcut: 1x1 conv on the skip path)
    filters1, filters2, filters3 = filter_list

    x = Conv2D(filters=filters1, kernel_size=(1, 1), strides=strides, padding='same')(input_x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    x = Conv2D(filters=filters2, kernel_size=(3, 3), strides=1, padding='same', dilation_rate=dilation_rate)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    x = Conv2D(filters=filters3, kernel_size=(1, 1), strides=1, padding='same')(x)
    x = BatchNormalization()(x)

    # shortcut branch: 1x1 projection so the Add sees matching shapes
    shortcut = Conv2D(filters=filters3, kernel_size=(1, 1), strides=strides, padding='same')(input_x)
    shortcut = BatchNormalization()(shortcut)

    x = Add()([x, shortcut])
    x = Activation('relu')(x)

    return x
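
# How the two blocks differ: identity_block adds the block input directly to its output,
# so it only applies when the input already has filter_list[2] channels and the spatial
# size is unchanged; conv_block projects the shortcut through a strided 1x1 convolution,
# which is why every ResNet stage below starts with a conv_block (to change the channel
# count and, when strides=2, halve the resolution) before stacking identity_blocks.
# A quick shape check (the 64x64x64 input is an assumed example):
#   conv_block(Input((64, 64, 64)), [64, 64, 256], strides=1)  -> (None, 64, 64, 256)
#   identity_block(<that output>, [64, 64, 256])               -> (None, 64, 64, 256)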


def get_resnet50_encoder(inputs, downsample_factor=16):
    if downsample_factor == 16:
        block4_stride = 2
        block4_dilation = 1
        block5_dilation = 2
    elif downsample_factor == 8:
        block4_stride = 1
        block4_dilation = 2
        block5_dilation = 4
    else:
        raise ValueError('downsample_factor must be 8 or 16, got %r' % (downsample_factor,))

    block_list = [3, 4, 6, 3]

    # conv1
    x = ZeroPadding2D(padding=(1, 1))(inputs)
    x = Conv2D(filters=64, kernel_size=(3, 3), strides=2, padding='valid')(x)
    f1 = x
    x = BatchNormalization(axis=-1)(x)
    x = Activation('relu')(x)
    # -----------------------------------------------------
    x = ZeroPadding2D(padding=(1, 1))(x)
    x = Conv2D(filters=64, kernel_size=(3, 3), strides=1, padding='valid')(x)
    x = BatchNormalization(axis=-1)(x)
    x = Activation('relu')(x)

    x = ZeroPadding2D(padding=(1, 1))(x)
    x = Conv2D(filters=128, kernel_size=(3, 3), strides=1, padding='valid')(x)
    x = BatchNormalization(axis=-1)(x)
    x = Activation('relu')(x)

    x = ZeroPadding2D(padding=(1, 1))(x)
    x = MaxPooling2D(pool_size=(3, 3), strides=2, padding='valid')(x)  # padding already applied by ZeroPadding2D above
    # -----------------------------------------------------

    # conv2_x
    x = conv_block(input_x=x, filter_list=[64, 64, 256], strides=1)
    for i in range(block_list[0] - 1):
        x = identity_block(input_x=x, filter_list=[64, 64, 256])
    f2 = x

    # conv3_x
    x = conv_block(input_x=x, filter_list=[128, 128, 512])
    for i in range(block_list[1] - 1):
        x = identity_block(input_x=x, filter_list=[128, 128, 512])
    f3 = x

    # conv4_x
    x = conv_block(input_x=x, filter_list=[256, 256, 1024], strides=block4_stride)
    for i in range(block_list[2] - 1):
        x = identity_block(input_x=x, filter_list=[256, 256, 1024], dilation_rate=block4_dilation)
    f4 = x

    # conv5_x
    x = conv_block(input_x=x, filter_list=[512, 512, 2048], strides=1, dilation_rate=block5_dilation)
    for i in range(block_list[3] - 1):
        x = identity_block(input_x=x, filter_list=[512, 512, 2048], dilation_rate=block5_dilation)
    f5 = x

    return [f1, f2, f3, f4, f5]
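
# Rough shape bookkeeping for a square H x W input: f1 lives at H/2, f2 at H/4 and
# f3 at H/8. With downsample_factor=16, conv4_x strides down once more, so f4 and f5
# sit at H/16 (conv5_x keeps stride 1 and uses dilation instead); with
# downsample_factor=8, f4 and f5 stay at H/8 with larger dilation rates. build_model
# below only consumes f4 (auxiliary head) and f5 (PSP head).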


def pool_block(feats, pool_factor):
    h = K.int_shape(feats)[1]
    w = K.int_shape(feats)[2]
    pool_size = strides = [int(np.round(float(h) / pool_factor)), int(np.round(float(w) / pool_factor))]
    x = AveragePooling2D(pool_size=pool_size, strides=strides, padding='same', data_format=IMAGE_ORDERING)(feats)
    x = Conv2D(filters=512, kernel_size=(1, 1), padding='same', data_format=IMAGE_ORDERING, use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Lambda(
        lambda x: tf.compat.v1.image.resize_images(x, (K.int_shape(feats)[1], K.int_shape(feats)[2]),
        align_corners=True)
    )(x)
    return x
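
# Illustrative arithmetic for pool_block (the 30x30 feature map is an assumed example):
# with h = w = 30 and pool_factor = 6, pool_size = strides = round(30 / 6) = 5, so the
# average pooling produces a 6x6 map; the 1x1 convolution projects it to 512 channels
# and the bilinear resize (align_corners=True) scales it back to 30x30 so it can be
# concatenated with the other pyramid levels and the original feature map.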


def build_model(tif_size, bands, class_num):
    from pathlib import Path
    import sys
    print('===== %s =====' % Path(__file__).name)
    print('===== %s =====' % sys._getframe().f_code.co_name)

    inputs = Input(shape=(tif_size, tif_size, bands))

    aux_branch = True  # whether to enable the auxiliary loss branch
    downsample_factor = 16  # 8 or 16
    inputs_size = (tif_size, tif_size, bands)

    levels = get_resnet50_encoder(inputs, downsample_factor)
    [f1, f2, f3, f4, f5] = levels

    # -------------------------------------#
    #   PSP module:
    #   average pooling over regions at several scales
    # -------------------------------------#
    pool_factors = [1, 2, 3, 6]
    o = f5
    pool_outs = [o]
    for p in pool_factors:
        pooled = pool_block(o, p)
        pool_outs.append(pooled)
    o = Concatenate(axis=MERGE_AXIS)(pool_outs)
    # -------------------------------------#

    o = Conv2D(512, (3, 3), data_format=IMAGE_ORDERING, padding='same', use_bias=False)(o)
    o = BatchNormalization()(o)
    o = Activation('relu')(o)
    o = Dropout(0.1)(o)

    o = Conv2D(class_num, (1, 1), data_format=IMAGE_ORDERING, padding='same')(o)
    # tf.compat.v1.image.resize_images expects size as (new_height, new_width)
    o = Lambda(lambda x: tf.compat.v1.image.resize_images(x, (inputs_size[0], inputs_size[1]), align_corners=True))(o)
    o = Activation("softmax", name="main")(o)

    # auxiliary loss branch (supervises f4 during training)
    if aux_branch:
        f4 = Conv2D(256, (3, 3), data_format=IMAGE_ORDERING, padding='same', use_bias=False)(f4)
        f4 = BatchNormalization()(f4)
        f4 = Activation('relu')(f4)
        f4 = Dropout(0.1)(f4)
        f4 = Conv2D(class_num, (1, 1), data_format=IMAGE_ORDERING, padding='same')(f4)
        f4 = Lambda(
            lambda x: tf.compat.v1.image.resize_images(x, (inputs_size[0], inputs_size[1]), align_corners=True))(f4)
        f4 = Activation("softmax", name="aux")(f4)
        model = Model(inputs, [f4, o])  # two outputs: only o is needed at inference; f4 only provides auxiliary supervision, and o outperforms it
        return model
    else:
        model = Model(inputs, [o])
        return model
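

# Minimal usage sketch. The 473x473 RGB input and 21 classes are illustrative values
# (473x473 is a crop size commonly used with PSPNet), not settings from this script;
# any square size that the stride-16 encoder can handle works.
if __name__ == '__main__':
    pspnet = build_model(tif_size=473, bands=3, class_num=21)
    pspnet.summary()  # with aux_branch=True there are two softmax outputs: "aux" and "main"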