"""
Created on 2021/1/4 20:25.
@Author: anne
"""
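# Reference: detailed walkthrough of the BiSeNet network structure
#   https://blog.csdn.net/Hanghang_/article/details/108592828
# Reference: detailed explanation of the related paper
#   https://blog.csdn.net/TTLoveYuYu/article/details/114372733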
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, BatchNormalization
from tensorflow.keras.layers import UpSampling2D, ZeroPadding2D, Activation, Reshape
from tensorflow.keras.layers import SeparableConv2D, add, GlobalAveragePooling2D, multiply, concatenate
from tensorflow.keras import Model
# Attention Refinement Module (ARM)
def ARM(x, old_filter_num):
    """Attention Refinement Module: re-weight the channels of ``x``.

    A gate is computed from the globally average-pooled input
    (1x1 convolution -> batch normalization -> sigmoid) and multiplied
    back onto ``x``.

    Args:
        x: Keras feature tensor. Assumed to be channels-last with
            ``old_filter_num`` channels, since the pooled vector is reshaped
            to ``(1, 1, old_filter_num)``.
        old_filter_num: Channel count used both for the reshape and as the
            number of filters of the 1x1 convolution.

    Returns:
        The element-wise product of the sigmoid gate and ``x``.
    """
    # Squeeze the spatial dimensions, then restore two singleton axes so the
    # gate can be multiplied with the full feature map.
    gate = GlobalAveragePooling2D()(x)
    gate = Reshape((1, 1, old_filter_num))(gate)
    gate = Conv2D(old_filter_num, 1, padding='same')(gate)
    gate = BatchNormalization()(gate)
    gate = Activation('sigmoid')(gate)
    return multiply([gate, x])
# Feature Fusion Module (FFM)
def FFM(spatial_path, context_path, num_classes):
    """Feature Fusion Module: merge spatial-path and context-path features.

    The two inputs are concatenated along the last axis and reduced to
    ``num_classes`` channels by a 3x3 conv-BN-ReLU block. A channel gate
    (global average pooling -> 1x1 conv -> ReLU -> 1x1 conv -> sigmoid) then
    re-weights the fused features, and the re-weighted features are added
    back to the fused features.

    Args:
        spatial_path: Feature tensor produced by the spatial path.
        context_path: Feature tensor produced by the context path; it must be
            concatenable with ``spatial_path`` along the last axis.
        num_classes: Number of channels of the fused output.

    Returns:
        The fused feature tensor with ``num_classes`` channels.
    """
    fused = concatenate([spatial_path, context_path], axis=-1)
    fused = Conv2D(num_classes, 3, padding='same')(fused)
    fused = BatchNormalization()(fused)
    fused = Activation('relu')(fused)

    # Channel gate computed from the pooled fused features.
    gate = GlobalAveragePooling2D()(fused)
    gate = Reshape((1, 1, num_classes))(gate)
    for activation_name in ('relu', 'sigmoid'):
        gate = Conv2D(num_classes, 1, padding='same')(gate)
        gate = Activation(activation_name)(gate)

    # Residual form: fused + gate * fused.
    reweighted = multiply([gate, fused])
    return add([reweighted, fused])
def spatial_path(x):
    """Spatial path: three stride-2 conv-BN-ReLU stages (64, 128, 256 filters).

    Each stage uses a 3x3 convolution with stride 2 and ``'same'`` padding,
    so the three stages together downsample by a factor of 8.

    Args:
        x: Input Keras tensor.

    Returns:
        The feature tensor after the third stage (256 channels).
    """
    for filters in (64, 128, 256):
        x = Conv2D(filters, 3, strides=2, padding='same')(x)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)
    return x
# Encoder: Xception building blocks
# ======================================================================
def entry_flow_1(x, filter_num):
    """First Xception entry-flow block (no leading ReLU on the main branch).

    Main branch: SepConv-BN-ReLU-SepConv-BN followed by 2x2 max pooling.
    Shortcut branch: 1x1 convolution with stride 2 followed by BN.
    The two branches are summed.

    NOTE(review): the pooling uses the default (valid) padding while the
    shortcut uses ``'same'`` padding, so the branch shapes presumably only
    agree when the input height and width are even.

    Args:
        x: Input Keras tensor.
        filter_num: Number of filters for every convolution in the block.

    Returns:
        The sum of the main and shortcut branches (spatially halved).
    """
    main = SeparableConv2D(filter_num, 3, padding='same')(x)
    main = BatchNormalization()(main)
    main = Activation('relu')(main)
    main = SeparableConv2D(filter_num, 3, padding='same')(main)
    main = BatchNormalization()(main)
    main = MaxPooling2D(pool_size=2)(main)

    # Projection shortcut so channel count and resolution match the main branch.
    shortcut = Conv2D(filter_num, 1, strides=2, padding='same')(x)
    shortcut = BatchNormalization()(shortcut)

    return add([main, shortcut])
def entry_flow_2and3(x, filter_num):
    """Second/third Xception entry-flow block (pre-activated main branch).

    Main branch: two ReLU-SepConv-BN units followed by 2x2 max pooling.
    Shortcut branch: 1x1 convolution with stride 2 followed by BN, applied to
    the un-activated input. The two branches are summed.

    NOTE(review): the pooling uses the default (valid) padding while the
    shortcut uses ``'same'`` padding, so the branch shapes presumably only
    agree when the input height and width are even.

    Args:
        x: Input Keras tensor.
        filter_num: Number of filters for every convolution in the block.

    Returns:
        The sum of the main and shortcut branches (spatially halved).
    """
    main = x
    for _ in range(2):
        main = Activation('relu')(main)
        main = SeparableConv2D(filter_num, 3, padding='same')(main)
        main = BatchNormalization()(main)
    main = MaxPooling2D(pool_size=2)(main)

    # Projection shortcut so channel count and resolution match the main branch.
    shortcut = Conv2D(filter_num, 1, strides=2, padding='same')(x)
    shortcut = BatchNormalization()(shortcut)

    return add([main, shortcut])
def middle_flow(x, filter_num):
    """Xception middle-flow block: three ReLU-SepConv-BN units plus identity.

    The input is added back unchanged, so ``x`` must already have
    ``filter_num`` channels for the addition to be valid.

    Args:
        x: Input Keras tensor.
        filter_num: Number of filters of each separable convolution.

    Returns:
        The sum of the residual branch and ``x`` (same resolution as ``x``).
    """
    residual = x
    for _ in range(3):
        residual = Activation('relu')(residual)
        residual = SeparableConv2D(filter_num, 3, padding='same')(residual)
        residual = BatchNormalization()(residual)
    return add([residual, x])
# ======================================================================
def _xception_encoder(inputs):
    """Build the modified Xception encoder used as the context-path backbone.

    Args:
        inputs: The model's Keras input tensor.

    Returns:
        A tuple ``(x_down16, x_final)``:
        ``x_down16`` is the 1024-channel batch-normalized exit-flow feature
        taken before the last pooling (16x downsampled); ``x_final`` is the
        2048-channel batch-normalized output after the last pooling
        (32x downsampled), before any final activation.
    """
    # Stem: explicit zero padding + 'valid' 3x3 convolutions.
    x = ZeroPadding2D(padding=(1, 1))(inputs)
    x = Conv2D(32, kernel_size=(3, 3), strides=(2, 2), padding='valid')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = ZeroPadding2D(padding=(1, 1))(x)
    x = Conv2D(64, kernel_size=(3, 3), strides=(1, 1), padding='valid')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    # Entry flow: each block halves the resolution.
    x = entry_flow_1(x, 128)
    x = entry_flow_2and3(x, 256)
    x = entry_flow_2and3(x, 728)

    # Middle flow: eight identical residual blocks.
    for _ in range(8):
        x = middle_flow(x, 728)
    x_exit_first = x  # kept for the exit-flow projection shortcut

    # Exit flow, main branch up to the 16x-downsampled tap.
    x = Activation('relu')(x_exit_first)
    x = SeparableConv2D(728, kernel_size=(3, 3), strides=(1, 1), padding='same')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = SeparableConv2D(1024, kernel_size=(3, 3), strides=(1, 1), padding='same')(x)
    x_down16 = BatchNormalization()(x)

    # Final halving plus projection shortcut from the middle-flow output.
    x = MaxPooling2D(pool_size=(2, 2))(x_down16)
    x_shortcut = Conv2D(1024, kernel_size=(1, 1), strides=(2, 2), padding='same')(x_exit_first)
    x_shortcut = BatchNormalization()(x_shortcut)
    x = add([x, x_shortcut])

    x = SeparableConv2D(1536, kernel_size=(3, 3), strides=(1, 1), padding='same')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = SeparableConv2D(2048, kernel_size=(3, 3), strides=(1, 1), padding='same')(x)
    x = BatchNormalization()(x)
    return x_down16, x


def _context_path(x_down16, x_final):
    """Build the BiSeNet context path from the two encoder feature maps.

    Args:
        x_down16: 1024-channel encoder feature (16x downsampled).
        x_final: 2048-channel encoder output (32x downsampled), not yet
            passed through its final ReLU.

    Returns:
        The concatenation (last axis) of both refined features after they
        have been upsampled by 4x and 2x respectively.
    """
    x_down32 = Activation('relu')(x_final)

    # Global-context vector, multiplied onto the refined 32x feature.
    x_global = GlobalAveragePooling2D()(x_down32)
    x_global = Reshape((1, 1, 2048))(x_global)
    x_down32 = ARM(x_down32, 2048)
    x_down32 = multiply([x_down32, x_global])

    x_down16 = ARM(x_down16, 1024)

    # Bring both features to the same resolution (8x downsampled).
    x_down32 = UpSampling2D(size=(4, 4))(x_down32)
    x_down16 = UpSampling2D(size=(2, 2))(x_down16)
    return concatenate([x_down32, x_down16], axis=-1)


def build_model(tif_size, bands, class_num):
    """Build a BiSeNet segmentation model with a modified Xception backbone.

    Prints the module file name and this function's name as a banner, then
    assembles: input -> Xception encoder -> spatial path + context path ->
    feature fusion -> 8x upsampling -> 1x1 softmax convolution.

    NOTE(review): the encoder downsamples by 32 and the head upsamples by a
    fixed factor, so ``tif_size`` is presumably expected to be a multiple of
    32 for the output to match the input size -- confirm with callers.

    Args:
        tif_size: Height and width of the (square) input image.
        bands: Number of input channels.
        class_num: Number of output classes (channels of the softmax output).

    Returns:
        A ``tensorflow.keras.Model`` mapping the input image to per-pixel
        class probabilities.
    """
    from pathlib import Path
    import sys

    # ``__file__`` is not defined in notebooks / REPL / exec'd code; fall back
    # to a placeholder instead of raising NameError before the model is built.
    module_file = globals().get('__file__')
    module_name = Path(module_file).name if module_file else '<interactive>'
    print('===== %s =====' % module_name)
    print('===== %s =====' % sys._getframe().f_code.co_name)

    # 1. Input
    inputs = Input(shape=(tif_size, tif_size, bands))

    # 2. Encoder: modified Xception (similar to the DeepLabv3+ variant)
    x_down16, x = _xception_encoder(inputs)

    # 3. Decoder: BiSeNet
    # 3-1 Spatial path --> low-level features
    result_sp = spatial_path(inputs)
    # 3-2 Context path --> high-level features (Xception backbone)
    result_cp = _context_path(x_down16, x)
    # 3-3 Feature fusion, then upsample back by 8x
    x = FFM(result_sp, result_cp, class_num)
    x = UpSampling2D(size=(8, 8))(x)

    # 4. Output
    x = Conv2D(class_num, (1, 1), strides=(1, 1), padding='same', activation='softmax')(x)
    return Model(inputs, x)
# model = build_model(256, 3, 2)
# model.summary()