# TF2 implementation of the Res-UNet semantic segmentation network (TF2实现语义分割网络Res-UNet)



"""
Created on 2020/11/29 19:42.

@Author: yubaby@anne
@Email: yhaif@foxmail.com
"""


from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Dropout, BatchNormalization, Activation
from tensorflow.keras.layers import Conv2DTranspose, Add, concatenate
from tensorflow.keras import Model


def residual_block(input_x, input_filters, is_activate=False):
    """Apply a pre-activation residual unit with an identity shortcut.

    The residual branch is ``(BatchNorm -> ReLU -> 3x3 conv)`` applied twice;
    its result is added element-wise to ``input_x``.

    Args:
        input_x: Keras tensor entering the block. Because it is added directly
            to the branch output (stride 1, 'same' padding, ``input_filters``
            channels), its channel count must equal ``input_filters``.
        input_filters: number of filters used by both 3x3 convolutions.
        is_activate: when True, a BatchNorm + ReLU is applied after the
            residual addition.

    Returns:
        The Keras tensor produced by the block.
    """
    branch = input_x
    # Two identical pre-activation conv units; order is BN, ReLU, Conv each time.
    for _ in range(2):
        branch = BatchNormalization()(branch)
        branch = Activation('relu')(branch)
        branch = Conv2D(filters=input_filters, kernel_size=3, strides=1, padding='same')(branch)

    merged = Add()([branch, input_x])
    if not is_activate:
        return merged

    normalized = BatchNormalization()(merged)
    return Activation('relu')(normalized)


def _conv_stage(x, filters):
    """Project ``x`` to ``filters`` channels with a 3x3 conv, then apply two residual blocks."""
    x = Conv2D(filters=filters, kernel_size=3, strides=1, padding='same')(x)
    x = residual_block(x, filters)
    # The second residual block also applies BN + ReLU after its addition.
    return residual_block(x, filters, True)


def _encoder_stage(x, filters, dropout_rate):
    """Run one encoder level; return ``(skip, downsampled)`` tensors.

    ``skip`` is the full-resolution stage output used by the decoder;
    ``downsampled`` is it after 2x max-pooling and optional dropout
    (skipped when ``dropout_rate`` is None).
    """
    skip = _conv_stage(x, filters)
    down = MaxPooling2D(pool_size=2)(skip)
    if dropout_rate is not None:
        down = Dropout(rate=dropout_rate)(down)
    return skip, down


def _decoder_stage(x, skip, filters, dropout_rate):
    """Upsample ``x`` 2x, concatenate it with ``skip``, and refine with a conv stage.

    Dropout is applied after the concatenation unless ``dropout_rate`` is None.
    """
    up = Conv2DTranspose(filters=filters, kernel_size=3, strides=2, padding='same')(x)
    up = concatenate([up, skip])
    if dropout_rate is not None:
        up = Dropout(rate=dropout_rate)(up)
    return _conv_stage(up, filters)


def build_model(tif_size, bands, class_num, *, init_filters=64, use_dropout=True):
    """Build a Res-UNet semantic-segmentation model.

    Architecture: four encoder levels (conv + two residual blocks, then 2x
    max-pooling), a bottleneck at ``16 * init_filters`` channels, four mirrored
    decoder levels (2x transposed conv, skip concatenation, conv + two residual
    blocks), and a final 1x1 conv with softmax over ``class_num`` channels.

    As a side effect, prints a banner with this module's file name and the
    function name.

    Args:
        tif_size: height and width of the square input tiles. Must be a
            multiple of 16 (2 ** number of pooling levels) so every pooled size
            is restored exactly by the decoder; ``None`` skips the check and
            builds a variable-size input.
        bands: number of input channels.
        class_num: number of output classes (channels of the softmax output).
        init_filters: channel width of the first level, doubled at each deeper
            level. Defaults to 64, the previously hard-coded value.
        use_dropout: whether to insert Dropout layers (rate 0.25 after the first
            pooling, 0.5 after the other poolings and after every decoder
            concatenation). Defaults to True, the previous behavior.

    Returns:
        An uncompiled ``Model`` mapping ``(tif_size, tif_size, bands)`` inputs to
        ``(tif_size, tif_size, class_num)`` per-pixel softmax outputs.

    Raises:
        ValueError: if ``tif_size`` is not None and not a multiple of 16.
    """
    from pathlib import Path

    depth = 4  # number of 2x down/up-sampling levels
    factor = 2 ** depth
    # MaxPooling floors odd sizes (e.g. 25 -> 12) while the stride-2 transposed
    # conv exactly doubles (12 -> 24); the skip concatenation would then fail
    # with an opaque shape error, so reject such sizes up front.
    if tif_size is not None and tif_size % factor != 0:
        raise ValueError(
            'tif_size must be a multiple of %d for a %d-level Res-UNet, got %r'
            % (factor, depth, tif_size))

    print('===== %s =====' % Path(__file__).name)
    print('===== %s =====' % build_model.__name__)

    # Per-level encoder dropout: lighter on the shallowest level.
    encoder_rates = (0.25, 0.5, 0.5, 0.5)
    decoder_rate = 0.5

    # input
    inputs = Input(shape=(tif_size, tif_size, bands))

    # encoder
    x = inputs
    skips = []
    for level, rate in enumerate(encoder_rates):
        skip, x = _encoder_stage(x, init_filters * 2 ** level, rate if use_dropout else None)
        skips.append(skip)

    # middle (bottleneck)
    x = _conv_stage(x, init_filters * 2 ** depth)

    # decoder: walk the levels back up, pairing each with its encoder skip
    for level in reversed(range(depth)):
        x = _decoder_stage(x, skips[level], init_filters * 2 ** level,
                           decoder_rate if use_dropout else None)

    # output
    outputs = Conv2D(class_num, (1, 1), activation='softmax', name='outputs')(x)

    return Model(inputs, outputs)