# TF2实现语义分割网络UNet (UNet semantic segmentation network implemented in TensorFlow 2)



"""
Created on 2021/1/26 22:01.
@Author: anne
"""


from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Dropout
from tensorflow.keras.layers import UpSampling2D, concatenate, Conv2DTranspose
from tensorflow.keras import Model


def _conv_pair(tensor, filters):
    """Apply two successive 3x3, same-padded, ReLU-activated convolutions.

    Args:
        tensor: Keras tensor to transform.
        filters: Number of kernels used by each of the two convolutions.

    Returns:
        The output tensor of the second convolution.
    """
    for _ in range(2):
        tensor = Conv2D(filters, (3, 3), activation="relu", padding="same")(tensor)
    return tensor


def _decoder_stage(tensor, skip, filters):
    """Upsample `tensor` by 2, concatenate it with `skip`, then refine it.

    Args:
        tensor: Keras tensor coming from the previous (coarser) stage.
        skip: Encoder tensor to concatenate with; it is placed first along
            the channel axis (axis 3).
        filters: Number of kernels for the transposed convolution and for
            the two convolutions that follow.

    Returns:
        The output tensor of the second refinement convolution.
    """
    # Conv2D layers keep the default stride of 1; the transposed convolution
    # needs strides=2 so that it doubles the spatial size, which corresponds
    # to UpSampling2D(size=(2, 2)).
    upsampled = Conv2DTranspose(
        filters, (3, 3), strides=2, activation="relu", padding="same"
    )(tensor)
    merged = concatenate([skip, upsampled], axis=3)
    return _conv_pair(merged, filters)


def build_model(tif_size, bands, class_num):
    """Build a UNet semantic-segmentation model with the Keras functional API.

    The network has four encoder stages (64, 128, 256 and 512 kernels), each
    made of two 3x3 convolutions followed by 2x2 max-pooling, a 1024-kernel
    middle stage with dropout (rate 0.5) after each of its two convolutions,
    and four decoder stages that upsample with `Conv2DTranspose` and
    concatenate the matching encoder output. The final layer, named
    ``'outputs'``, is a 1x1 convolution with `class_num` kernels and a
    softmax activation.

    The name of this file and of this function are printed to stdout each
    time the function is called.

    Args:
        tif_size: Height and width of the (square) input tiles. Because the
            encoder halves the spatial size four times and the decoder
            doubles it four times, an integer size must be a multiple of 16
            for the skip connections to line up.
        bands: Number of channels in the input tiles.
        class_num: Number of classes, i.e. channels of the softmax output.

    Returns:
        A `tensorflow.keras.Model` mapping the input to the softmax output.
        The model is not compiled.

    Raises:
        ValueError: If `tif_size` is an ``int`` that is not a multiple of 16.
            Values of any other type are not checked here and are handed to
            Keras unchanged.
    """
    from pathlib import Path

    print('===== %s =====' % Path(__file__).name)
    print('===== %s =====' % build_model.__name__)

    # Fail early with a clear message: otherwise the mismatch only surfaces
    # later as a shape error raised by one of the concatenate calls.
    pool_stages = 4
    size_multiple = 2 ** pool_stages
    if isinstance(tif_size, int) and tif_size % size_multiple != 0:
        raise ValueError(
            "tif_size must be a multiple of %d so that the decoder outputs "
            "match the encoder skip connections, got %r"
            % (size_multiple, tif_size)
        )

    # 1 input
    inputs = Input(shape=(tif_size, tif_size, bands))

    init_filters = 64  # number of kernels used by the original UNet

    # 2 encoder: keep every pre-pooling tensor for the skip connections
    skips = []
    x = inputs
    for multiplier in (1, 2, 4, 8):
        x = _conv_pair(x, init_filters * multiplier)
        skips.append(x)
        x = MaxPooling2D(pool_size=(2, 2))(x)

    # 3 middle: each convolution is followed by dropout
    for _ in range(2):
        x = Conv2D(init_filters * 16, (3, 3), activation="relu", padding="same")(x)
        x = Dropout(0.5)(x)

    # 4 decoder: implemented with transposed convolutions (Conv2DTranspose),
    # consuming the skip tensors from the deepest to the shallowest
    for multiplier in (8, 4, 2, 1):
        x = _decoder_stage(x, skips.pop(), init_filters * multiplier)

    # 5 output
    outputs = Conv2D(class_num, (1, 1), activation='softmax', name='outputs')(x)

    return Model(inputs, outputs)