"""
Created on 2020/11/29 19:59.
@Author: yubaby@anne
@Email: yhaif@foxmail.com
"""
from tensorflow.keras.layers import Input, Conv2D, BatchNormalization, Activation
from tensorflow.keras.layers import Add, Concatenate, UpSampling2D, SeparableConv2D
from tensorflow.keras.layers import GlobalAveragePooling2D, Reshape
from tensorflow.keras import Model
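# Encoder backbone: modified Xception built from stacked (serial) separable convolutions; no layer in it sets a dilation rate.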
def Xception(inputs):
    """Build the modified-Xception encoder backbone.

    The repeated residual modules are produced by ``_xception_block``. Layers
    are created in exactly the same order as the former inline implementation,
    so auto-generated Keras layer names (and name-based weight loading) are
    unchanged.

    Args:
        inputs: Keras input tensor (the model's ``Input``).

    Returns:
        A tuple ``(x, x_low_level_feature)``:
        ``x`` -- 2048-channel output of the exit flow. Its spatial size is the
        input size divided by 16: one stride-2 stem conv plus three stride-2
        entry-flow blocks, each with ``padding='same'``, so sizes round up
        when not divisible.
        ``x_low_level_feature`` -- 256-channel ReLU activation inside the
        second entry-flow block, taken just before that block's stride-2
        separable conv (input size divided by 4).
    """
    # ---------------------------------------------------------------
    # Entry flow, stem: two plain convolutions (the first halves H and W).
    # ---------------------------------------------------------------
    x = Conv2D(32, kernel_size=(3, 3), strides=(2, 2), padding='same')(inputs)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Conv2D(64, kernel_size=(3, 3), strides=(1, 1), padding='same')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    # ---------------------------------------------------------------
    # Entry flow, residual blocks (each halves H and W).
    # SeparableConv2D = depthwise conv followed by a pointwise (1x1) conv.
    # ---------------------------------------------------------------
    # Block 1 is not pre-activated: its input already passed the ReLU above.
    x, _ = _xception_block(x, (128, 128, 128), strides=(2, 2), pre_activation=False)
    # Block 2 additionally yields the low-level feature used by the decoder.
    x, x_low_level_feature = _xception_block(x, (256, 256, 256), strides=(2, 2))
    x, _ = _xception_block(x, (728, 728, 728), strides=(2, 2))

    # ---------------------------------------------------------------
    # Middle flow: 16 identity-shortcut blocks that deepen the network
    # without changing resolution or width.
    # ---------------------------------------------------------------
    for _ in range(16):
        x = _xception_block(x, (728, 728, 728), strides=(1, 1), conv_shortcut=False)[0]

    # ---------------------------------------------------------------
    # Exit flow.
    # NOTE(review): no layer in this backbone sets dilation_rate. If atrous
    # convolutions were intended here (as in the DeepLabv3+ exit flow at
    # output stride 16), they are missing -- confirm.
    # ---------------------------------------------------------------
    # Block 1: widens to 1024 channels, stride 1 (resolution kept).
    x, _ = _xception_block(x, (728, 1024, 1024), strides=(1, 1))
    # Block 2: three separable convolutions without a residual connection.
    for filters in (1536, 1536, 2048):
        x = SeparableConv2D(filters, kernel_size=(3, 3), strides=(1, 1), padding='same')(x)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)
    return x, x_low_level_feature


def _xception_block(x, filters, strides, conv_shortcut=True, pre_activation=True):
    """Residual block of three 3x3 separable convolutions (Xception module).

    Args:
        x: input Keras tensor.
        filters: 3-tuple with the output channel counts of the three
            separable convolutions.
        strides: stride of the third separable convolution and of the 1x1
            shortcut convolution.
        conv_shortcut: if True, the shortcut is a 1x1 Conv2D + BN projecting
            ``x`` to ``filters[-1]`` channels at ``strides``. Otherwise the
            shortcut is ``x`` itself, which then must already have
            ``filters[-1]`` channels, and ``strides`` must be 1.
        pre_activation: apply a ReLU to ``x`` before the first separable conv.

    Returns:
        ``(out, features)``. ``out`` is the sum of the residual branch and the
        shortcut, with no ReLU applied. ``features`` is the ReLU output that
        feeds the third separable convolution.
    """
    # The shortcut layers are created first to keep the original layer order.
    if conv_shortcut:
        shortcut = Conv2D(filters[-1], kernel_size=(1, 1), strides=strides, padding='same')(x)
        shortcut = BatchNormalization()(shortcut)
    else:
        shortcut = x
    y = Activation('relu')(x) if pre_activation else x
    y = SeparableConv2D(filters[0], kernel_size=(3, 3), strides=(1, 1), padding='same')(y)
    y = BatchNormalization()(y)
    y = Activation('relu')(y)
    y = SeparableConv2D(filters[1], kernel_size=(3, 3), strides=(1, 1), padding='same')(y)
    y = BatchNormalization()(y)
    features = Activation('relu')(y)
    y = SeparableConv2D(filters[2], kernel_size=(3, 3), strides=strides, padding='same')(features)
    y = BatchNormalization()(y)
    return Add()([y, shortcut]), features
# Encoder head: ASPP, i.e. parallel atrous (dilated) convolutions plus an image-level pooling branch.
def ASPP(x, filter_num, old_filter_num=None):
    """Atrous Spatial Pyramid Pooling head.

    Runs five parallel branches on ``x``, each followed by BN and ReLU:
    - image-level pooling;
    - a 1x1 conv;
    - three 3x3 convs with dilation rates 6, 12 and 18.

    The branches are concatenated along the channel axis and fused with a
    1x1 conv + BN + ReLU.

    Args:
        x: 4-D Keras tensor, assumed (batch, H, W, C) channels-last (Keras'
            default data_format) -- the Reshape below relies on channels
            being the last axis.
        filter_num: output channels of every branch and of the fused result.
        old_filter_num: channel count C of ``x``. Optional; inferred from
            ``x.shape[-1]`` when omitted.

    Returns:
        Keras tensor with the same spatial size as ``x`` and ``filter_num``
        channels.
    """
    if old_filter_num is None:
        old_filter_num = x.shape[-1]
    # The pooled branch is a 1x1 map that must be tiled back to x's spatial
    # size before concatenation. Read that size from the static shape rather
    # than hard-coding 16x16, which only matched a 256x256 input at output
    # stride 16 and broke the Concatenate for any other input size.
    height, width = x.shape[1], x.shape[2]
    if height is None or width is None:
        # Spatial size unknown at graph-build time: keep the legacy 16x16
        # tiling so behaviour is unchanged for such models.
        height, width = 16, 16

    # Image-level features: global average pool -> 1x1 conv -> tile back.
    x_pool = GlobalAveragePooling2D()(x)
    x_pool = Reshape((1, 1, old_filter_num))(x_pool)
    x_pool = Conv2D(filter_num, kernel_size=(1, 1), strides=(1, 1), padding='same')(x_pool)
    x_pool = BatchNormalization()(x_pool)
    x_pool = Activation('relu')(x_pool)
    x_pool = UpSampling2D(size=(height, width))(x_pool)

    # 1x1 convolution branch.
    x_1 = Conv2D(filter_num, kernel_size=(1, 1), strides=(1, 1), padding='same')(x)
    x_1 = BatchNormalization()(x_1)
    x_1 = Activation('relu')(x_1)

    # 3x3 atrous convolution branches with growing receptive fields.
    x_6 = Conv2D(filter_num, kernel_size=(3, 3), strides=(1, 1), padding='same', dilation_rate=6)(x)
    x_6 = BatchNormalization()(x_6)
    x_6 = Activation('relu')(x_6)
    x_12 = Conv2D(filter_num, kernel_size=(3, 3), strides=(1, 1), padding='same', dilation_rate=12)(x)
    x_12 = BatchNormalization()(x_12)
    x_12 = Activation('relu')(x_12)
    x_18 = Conv2D(filter_num, kernel_size=(3, 3), strides=(1, 1), padding='same', dilation_rate=18)(x)
    x_18 = BatchNormalization()(x_18)
    x_18 = Activation('relu')(x_18)

    # Fuse all branches.
    x = Concatenate()([x_pool, x_1, x_6, x_12, x_18])
    x = Conv2D(filter_num, kernel_size=(1, 1), strides=(1, 1), padding='same')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    return x
def build_model(tif_size, bands, class_num):
    """Build a DeepLabv3+-style semantic segmentation model.

    Args:
        tif_size: height and width of the square input tiles. Must be a
            positive multiple of 16: the encoder downsamples by 16, the decoder
            upsamples by 4 twice, and the skip connection at 1/4 resolution only
            lines up when the size divides evenly. ``None`` (dynamic size) is
            passed through unchecked, as before.
        bands: number of input channels (image bands).
        class_num: number of output classes (channels of the softmax map).

    Returns:
        An uncompiled ``tensorflow.keras.Model`` whose output is a per-pixel
        softmax over ``class_num`` classes.

    Raises:
        ValueError: if ``tif_size`` is not a positive multiple of 16.
    """
    from pathlib import Path
    print('===== %s =====' % Path(__file__).name)
    print('===== %s =====' % build_model.__name__)

    # Fail fast with a clear message instead of a cryptic Concatenate shape
    # mismatch deep inside the graph.
    if tif_size is not None and (tif_size <= 0 or tif_size % 16 != 0):
        raise ValueError(
            'tif_size must be a positive multiple of 16, got %r' % (tif_size,))

    inputs = Input(shape=(tif_size, tif_size, bands))
    # Low-level features carry spatial detail; high-level features carry semantics.
    x, x_low_level_feature = Xception(inputs)
    x_high_level_feature = ASPP(x, 256, 2048)

    # Decoder: reduce low-level channels, bring high-level features from 1/16
    # to 1/4 resolution, and merge them.
    x_low_level_feature = Conv2D(48, kernel_size=(1, 1), strides=(1, 1), padding='same')(x_low_level_feature)
    x_high_level_feature = UpSampling2D(size=(4, 4))(x_high_level_feature)
    x = Concatenate()([x_low_level_feature, x_high_level_feature])
    x = Conv2D(256, kernel_size=(3, 3), strides=(1, 1), padding='same')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Conv2D(256, kernel_size=(3, 3), strides=(1, 1), padding='same')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    # Back to full input resolution, then per-pixel class probabilities.
    x = UpSampling2D(size=(4, 4))(x)
    x = Conv2D(class_num, kernel_size=(1, 1), strides=(1, 1), padding='same', activation='softmax')(x)
    mymodel = Model(inputs, x)
    return mymodel