"""
Created on 2020/11/29 19:45.
@Author: yubaby@anne
@Email: yhaif@foxmail.com
"""
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, BatchNormalization, Activation
from tensorflow.keras.layers import Add, UpSampling2D
from tensorflow.keras import Model
def identity_block(input_x, filter_list):  # identity-shortcut ("solid-line") residual block
    """Bottleneck residual block whose shortcut is the identity.

    Main branch: 1x1 -> 3x3 -> 1x1 convolutions (all stride 1, 'same'
    padding). Each convolution is followed by batch normalization, and ReLU
    follows the first two. The branch output is summed with ``input_x`` and
    passed through a final ReLU.

    Args:
        input_x: Input Keras tensor.
        filter_list: Sequence of exactly three filter counts
            ``[filters1, filters2, filters3]``.

    Returns:
        Keras tensor with the same spatial size as ``input_x`` and
        ``filters3`` channels.
    """
    # NOTE: Keras Add needs both inputs to have the same shape, so input_x
    # must already carry filters3 channels.
    first, middle, last = filter_list
    stages = ((first, (1, 1)), (middle, (3, 3)), (last, (1, 1)))
    branch = input_x
    for idx, (n_filters, ksize) in enumerate(stages):
        branch = Conv2D(filters=n_filters, kernel_size=ksize, strides=1, padding='same')(branch)
        branch = BatchNormalization()(branch)
        if idx < len(stages) - 1:
            # The last stage's ReLU is applied only after the residual sum.
            branch = Activation('relu')(branch)
    merged = Add()([branch, input_x])
    return Activation('relu')(merged)
def conv_block(input_x, filter_list, strides=2):  # projection-shortcut ("dashed-line") residual block
    """Bottleneck residual block with a 1x1 projection shortcut.

    Main branch: 1x1 conv -> 3x3 conv (with ``strides``) -> 1x1 conv, all
    'same' padding. Each conv is followed by batch normalization, and ReLU
    follows the first two. Shortcut: 1x1 conv with ``filters3`` filters and
    the same ``strides``, followed by batch normalization. The shortcut
    therefore matches the main branch's shape and can be summed with it.
    A final ReLU is applied after the sum.

    Args:
        input_x: Input Keras tensor.
        filter_list: Sequence of exactly three filter counts
            ``[filters1, filters2, filters3]``.
        strides: Stride of the 3x3 main-branch conv and of the shortcut conv.
            Defaults to 2, the previously hard-coded value, which halves H
            and W. Pass 1 to change the channel count without downsampling.
            An example is the first block of the canonical ResNet-50 conv2_x
            stage, which directly follows a max-pool.

    Returns:
        Keras tensor with ``filters3`` channels.
    """
    filters1, filters2, filters3 = filter_list
    x = Conv2D(filters=filters1, kernel_size=(1, 1), strides=1, padding='same')(input_x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    # Downsampling (if any) happens in the 3x3 conv, as in the original block.
    x = Conv2D(filters=filters2, kernel_size=(3, 3), strides=strides, padding='same')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Conv2D(filters=filters3, kernel_size=(1, 1), strides=1, padding='same')(x)
    x = BatchNormalization()(x)
    # Shortcut: project to filters3 channels with the same stride so shapes match for Add.
    shortcut = Conv2D(filters=filters3, kernel_size=(1, 1), strides=strides, padding='same')(input_x)
    shortcut = BatchNormalization()(shortcut)
    x = Add()([x, shortcut])
    x = Activation('relu')(x)
    return x
def get_resnet50_encoder(inputs):
    """Build a ResNet-50-style encoder and return its intermediate feature maps.

    Layout: conv1 (7x7, stride 2) -> BN -> ReLU -> 3x3/2 max-pool ('valid'
    padding). Then come four bottleneck stages (conv2_x..conv5_x) with
    3, 4, 6, 3 blocks. Each stage opens with ``conv_block`` and continues
    with ``identity_block``.

    NOTE: ``conv_block`` is called with its default stride of 2 for every
    stage, including conv2_x. Canonical ResNet-50 uses stride 1 there because
    the max-pool has already downsampled. As a result, each stage output here
    is at half the canonical resolution.

    Args:
        inputs: Input Keras tensor (e.g. from ``Input``).

    Returns:
        List ``[f1, f2, f3, f4, f5]`` of Keras tensors:
        - f1 is the conv1 output taken before its BN/ReLU, at 1/2 of the
          input H and W.
        - f2..f5 are the outputs of conv2_x..conv5_x. When the input H and W
          are multiples of 64, they are at exactly 1/8, 1/16, 1/32 and 1/64
          resolution, with 256, 512, 1024 and 2048 channels.
    """
    # (number of blocks, [filters1, filters2, filters3]) for conv2_x..conv5_x.
    stages = (
        (3, [64, 64, 256]),
        (4, [128, 128, 512]),
        (6, [256, 256, 1024]),
        (3, [512, 512, 2048]),
    )

    # conv1
    x = Conv2D(filters=64, kernel_size=(7, 7), strides=2, padding='same')(inputs)
    f1 = x  # raw conv1 output (pre-BN/ReLU), 1/2 resolution
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = MaxPooling2D(pool_size=(3, 3), strides=2)(x)

    features = [f1]
    for num_blocks, filters in stages:
        # The first block of each stage projects the shortcut and halves H and W.
        x = conv_block(input_x=x, filter_list=filters)
        for _ in range(num_blocks - 1):
            x = identity_block(input_x=x, filter_list=filters)
        features.append(x)
    return features
def get_segnet_decoder(feature):
    """SegNet-style decoder made of five upsampling stages.

    Each stage is: 2x ``UpSampling2D`` -> 3x3 conv ('same' padding, ReLU)
    -> batch normalization. Stage widths are 512, 512, 256, 128 and 64
    channels, so H and W grow by 32x overall.

    Args:
        feature: Input Keras tensor (an encoder feature map).

    Returns:
        Keras tensor with 64 channels and 32x the input's H and W.
    """
    out = feature
    for width in (512, 512, 256, 128, 64):
        out = UpSampling2D(size=(2, 2))(out)
        out = Conv2D(width, (3, 3), strides=1, padding='same', activation='relu')(out)
        out = BatchNormalization()(out)
    return out
def build_model(tif_size, bands, class_num):
    """Build the segmentation model: ResNet-50-style encoder + SegNet decoder.

    Prints a banner containing this source file's name and this function's
    name. It then decodes the fourth encoder output (index 3, conv4_x)
    through the SegNet decoder and applies a 1x1 conv with per-pixel softmax
    over ``class_num`` classes.

    In this file, encoder index 3 is at 1/32 of the input resolution and the
    decoder upsamples 32x. The output is therefore
    ``(tif_size, tif_size, class_num)`` only when ``tif_size`` is a multiple
    of 32.

    Args:
        tif_size: Height and width of the square input tile.
        bands: Number of input channels.
        class_num: Number of output classes.

    Returns:
        An uncompiled ``tensorflow.keras.Model``.
    """
    import sys
    from pathlib import Path

    banner = '===== %s ====='
    print(banner % Path(__file__).name)
    print(banner % sys._getframe().f_code.co_name)  # name of this function

    image = Input(shape=(tif_size, tif_size, bands))
    encoder_levels = get_resnet50_encoder(image)
    decoded = get_segnet_decoder(feature=encoder_levels[3])
    probs = Conv2D(class_num, (1, 1), strides=1, padding='same', activation='softmax')(decoded)
    return Model(image, probs)