"""
Created on 2020/11/29 19:51.
@Author: yubaby@anne
@Email: yhaif@foxmail.com
"""
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Dropout, BatchNormalization, Activation
from tensorflow.keras.layers import Add, ZeroPadding2D, AveragePooling2D, Lambda, Concatenate
from tensorflow.keras import Model
import tensorflow.keras.backend as K
# Switch TensorFlow to TF1-style graph (non-eager) execution for the whole
# process. This is a module-level side effect that happens on import.
# NOTE(review): TensorFlow expects this call before any graph/op/tensor is
# created -- confirm no importer builds tensors before importing this module.
tf.compat.v1.disable_eager_execution()
# Tensor layout used by the layers that take an explicit ``data_format``.
IMAGE_ORDERING = 'channels_last'

# Channel axis along which feature maps are concatenated.
if IMAGE_ORDERING == 'channels_last':  # 'NHWC'
    MERGE_AXIS = -1
elif IMAGE_ORDERING == 'channels_first':  # 'NCHW'
    MERGE_AXIS = 1
def identity_block(input_x, filter_list, dilation_rate=1):
    """Bottleneck residual block whose shortcut is the unmodified input.

    Runs 1x1 -> 3x3 -> 1x1 convolutions (all stride 1, 'same' padding), each
    followed by batch normalization, with a ReLU after the first two. The
    block input is then added to the branch output and a final ReLU is
    applied.

    Args:
        input_x: Input feature tensor. It is fed to ``Add`` together with the
            branch output, so it needs to be shape-compatible with a tensor
            that has ``filter_list[2]`` channels.
        filter_list: Sequence of exactly three filter counts, one per
            convolution, in order.
        dilation_rate: Dilation rate of the middle 3x3 convolution.

    Returns:
        The tensor produced by the final ReLU activation.
    """
    squeeze_filters, spatial_filters, expand_filters = filter_list

    # (filters, kernel_size, dilation_rate) for each convolution, in order.
    conv_specs = (
        (squeeze_filters, (1, 1), 1),
        (spatial_filters, (3, 3), dilation_rate),
        (expand_filters, (1, 1), 1),
    )
    last_index = len(conv_specs) - 1

    branch = input_x
    for index, (n_filters, kernel, dilation) in enumerate(conv_specs):
        branch = Conv2D(n_filters, kernel, strides=1, padding='same', dilation_rate=dilation)(branch)
        branch = BatchNormalization()(branch)
        # The last convolution is activated only after the residual addition.
        if index != last_index:
            branch = Activation('relu')(branch)

    merged = Add()([branch, input_x])
    merged = Activation('relu')(merged)
    return merged
def conv_block(input_x, filter_list, strides=2, dilation_rate=1):
    """Bottleneck residual block with a projection (1x1 convolution) shortcut.

    Main branch: 1x1 convolution with ``strides`` -> 3x3 convolution with
    ``dilation_rate`` -> 1x1 convolution, each followed by batch
    normalization, with a ReLU after the first two. The shortcut applies a
    1x1 convolution with the same ``strides`` and ``filter_list[2]`` filters
    plus batch normalization to the block input. Both branches are added and
    passed through a final ReLU.

    Args:
        input_x: Input feature tensor.
        filter_list: Sequence of exactly three filter counts for the three
            main-branch convolutions; the last one is also used for the
            shortcut projection.
        strides: Stride of the first main-branch convolution and of the
            shortcut convolution.
        dilation_rate: Dilation rate of the middle 3x3 convolution.

    Returns:
        The tensor produced by the final ReLU activation.
    """
    squeeze_filters, spatial_filters, expand_filters = filter_list

    # (filters, kernel_size, strides, dilation_rate) per main-branch conv.
    conv_specs = (
        (squeeze_filters, (1, 1), strides, 1),
        (spatial_filters, (3, 3), 1, dilation_rate),
        (expand_filters, (1, 1), 1, 1),
    )
    last_index = len(conv_specs) - 1

    branch = input_x
    for index, (n_filters, kernel, stride, dilation) in enumerate(conv_specs):
        branch = Conv2D(n_filters, kernel, strides=stride, padding='same', dilation_rate=dilation)(branch)
        branch = BatchNormalization()(branch)
        # The last convolution is activated only after the residual addition.
        if index != last_index:
            branch = Activation('relu')(branch)

    # Shortcut: project the input to the branch's channel count and stride.
    projection = Conv2D(expand_filters, (1, 1), strides=strides, padding='same')(input_x)
    projection = BatchNormalization()(projection)

    merged = Add()([branch, projection])
    merged = Activation('relu')(merged)
    return merged
def get_resnet50_encoder(inputs, downsample_factor=16):
    """Build a ResNet-50 style encoder with dilated late stages.

    The stem is three 3x3 convolutions (the first with stride 2) followed by
    a stride-2 max pooling. Four residual stages with 3, 4, 6 and 3
    bottleneck blocks follow. ``downsample_factor`` selects how conv4_x and
    conv5_x trade striding for dilation:

    * ``16``: conv4_x uses stride 2 / dilation 1, conv5_x uses dilation 2.
    * ``8``:  conv4_x uses stride 1 / dilation 2, conv5_x uses dilation 4.

    Args:
        inputs: Input image tensor (the code uses ``axis=-1`` for batch
            normalization in the stem, i.e. channels-last).
        downsample_factor: Either ``8`` or ``16``.

    Returns:
        A list ``[f1, f2, f3, f4, f5]``: ``f1`` is the output of the first
        stem convolution (before batch normalization); ``f2``..``f5`` are the
        outputs of conv2_x..conv5_x.

    Raises:
        ValueError: If ``downsample_factor`` is neither 8 nor 16.
    """
    # downsample_factor -> (block4_stride, block4_dilation, block5_dilation)
    stage_settings = {
        16: (2, 1, 2),
        8: (1, 2, 4),
    }
    # Fail early with a clear message instead of hitting an unbound local
    # variable halfway through graph construction.
    if downsample_factor not in stage_settings:
        raise ValueError(
            'downsample_factor must be 8 or 16, got %r' % (downsample_factor,)
        )
    block4_stride, block4_dilation, block5_dilation = stage_settings[downsample_factor]

    block_list = [3, 4, 6, 3]  # number of bottleneck blocks per stage

    # conv1 (stem): three 3x3 convolutions with explicit zero padding
    x = ZeroPadding2D(padding=(1, 1))(inputs)
    x = Conv2D(filters=64, kernel_size=(3, 3), strides=2, padding='valid')(x)
    f1 = x  # tapped before batch normalization / activation
    x = BatchNormalization(axis=-1)(x)
    x = Activation('relu')(x)

    x = ZeroPadding2D(padding=(1, 1))(x)
    x = Conv2D(filters=64, kernel_size=(3, 3), strides=1, padding='valid')(x)
    x = BatchNormalization(axis=-1)(x)
    x = Activation('relu')(x)

    x = ZeroPadding2D(padding=(1, 1))(x)
    x = Conv2D(filters=128, kernel_size=(3, 3), strides=1, padding='valid')(x)
    x = BatchNormalization(axis=-1)(x)
    x = Activation('relu')(x)

    # NOTE(review): the map is zero-padded explicitly AND pooled with
    # padding='same', so the pooled size is ceil((H + 2) / 2) rather than
    # H / 2 -- confirm this is intended. Kept as-is to preserve behaviour.
    x = ZeroPadding2D(padding=(1, 1))(x)
    x = MaxPooling2D(pool_size=(3, 3), strides=2, padding='same')(x)

    # conv2_x
    x = conv_block(input_x=x, filter_list=[64, 64, 256], strides=1)
    for _ in range(block_list[0] - 1):
        x = identity_block(input_x=x, filter_list=[64, 64, 256])
    f2 = x

    # conv3_x (conv_block's default stride of 2 halves the resolution)
    x = conv_block(input_x=x, filter_list=[128, 128, 512])
    for _ in range(block_list[1] - 1):
        x = identity_block(input_x=x, filter_list=[128, 128, 512])
    f3 = x

    # conv4_x: the dilation is applied to the identity blocks only
    x = conv_block(input_x=x, filter_list=[256, 256, 1024], strides=block4_stride)
    for _ in range(block_list[2] - 1):
        x = identity_block(input_x=x, filter_list=[256, 256, 1024], dilation_rate=block4_dilation)
    f4 = x

    # conv5_x: never strided, always dilated
    x = conv_block(input_x=x, filter_list=[512, 512, 2048], strides=1, dilation_rate=block5_dilation)
    for _ in range(block_list[3] - 1):
        x = identity_block(input_x=x, filter_list=[512, 512, 2048], dilation_rate=block5_dilation)
    f5 = x

    return [f1, f2, f3, f4, f5]
def pool_block(feats, pool_factor):
    """Build one pyramid-pooling branch for ``feats``.

    The feature map is average-pooled with a window (and stride) of roughly
    ``size / pool_factor`` per spatial axis, projected to 512 channels with a
    bias-free 1x1 convolution, batch-normalized, ReLU-activated and resized
    back to the spatial size of ``feats``.

    Args:
        feats: Feature tensor with statically known sizes on axes 1 and 2.
            Those axes are treated as height and width, which matches the
            ``channels_last`` layout -- NOTE(review): confirm before using
            ``channels_first``.
        pool_factor: Divisor applied to each spatial size to get the pooling
            window.

    Returns:
        A tensor with 512 channels, resized to the height and width read
        from ``feats``.
    """
    height = K.int_shape(feats)[1]
    width = K.int_shape(feats)[2]

    # np.round rounds halves to even, so a small map (e.g. 3 / 6 = 0.5)
    # would yield a window of 0, which is not a valid pool size. Clamp to 1.
    window = [
        max(1, int(np.round(float(height) / pool_factor))),
        max(1, int(np.round(float(width) / pool_factor))),
    ]

    x = AveragePooling2D(pool_size=window, strides=window, padding='same', data_format=IMAGE_ORDERING)(feats)
    x = Conv2D(filters=512, kernel_size=(1, 1), padding='same', data_format=IMAGE_ORDERING, use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    # Resize back to the input's spatial size. The target is captured as
    # plain ints so the Lambda does not hold on to the ``feats`` tensor.
    x = Lambda(
        lambda t: tf.compat.v1.image.resize_images(t, (height, width), align_corners=True)
    )(x)
    return x
def build_model(tif_size, bands, class_num):
    """Assemble the PSPNet segmentation model.

    A ResNet-50 style encoder (downsample factor 16) feeds a pyramid pooling
    module with pool factors 1, 2, 3 and 6. The main head classifies the
    concatenated pyramid features; an auxiliary head classifies the conv4_x
    features. Both heads are resized to ``tif_size`` x ``tif_size`` and
    softmax-activated.

    Args:
        tif_size: Height and width of the (square) input image.
        bands: Number of input channels.
        class_num: Number of output classes.

    Returns:
        A Keras ``Model``. With the auxiliary branch enabled (the hard-coded
        default) its outputs are ``[aux, main]``; only the main output is
        needed for prediction, the auxiliary one exists to support training.
    """
    from pathlib import Path
    import sys

    # Announce which file and function are building the model.
    print('===== %s =====' % Path(__file__).name)
    print('===== %s =====' % sys._getframe().f_code.co_name)

    image = Input(shape=(tif_size, tif_size, bands))

    use_aux_head = True  # whether to add the auxiliary-loss branch
    encoder_stride = 16  # 8 or 16
    output_hw = (tif_size, tif_size)  # both heads are resized to this size

    _, _, _, f4, f5 = get_resnet50_encoder(image, encoder_stride)

    # ---- PSP module: pool the deepest features over several region sizes
    pyramid = [f5] + [pool_block(f5, factor) for factor in (1, 2, 3, 6)]
    main = Concatenate(axis=MERGE_AXIS)(pyramid)

    # ---- Main segmentation head
    main = Conv2D(512, (3, 3), data_format=IMAGE_ORDERING, padding='same', use_bias=False)(main)
    main = BatchNormalization()(main)
    main = Activation('relu')(main)
    main = Dropout(0.1)(main)
    main = Conv2D(class_num, (1, 1), data_format=IMAGE_ORDERING, padding='same')(main)
    main = Lambda(
        lambda t: tf.compat.v1.image.resize_images(t, output_hw, align_corners=True)
    )(main)
    main = Activation("softmax", name="main")(main)

    if not use_aux_head:
        return Model(image, [main])

    # ---- Auxiliary-loss head on the conv4_x features
    aux = Conv2D(256, (3, 3), data_format=IMAGE_ORDERING, padding='same', use_bias=False)(f4)
    aux = BatchNormalization()(aux)
    aux = Activation('relu')(aux)
    aux = Dropout(0.1)(aux)
    aux = Conv2D(class_num, (1, 1), data_format=IMAGE_ORDERING, padding='same')(aux)
    aux = Lambda(
        lambda t: tf.compat.v1.image.resize_images(t, output_hw, align_corners=True)
    )(aux)
    aux = Activation("softmax", name="aux")(aux)

    # Two outputs: use only the main one for prediction; the auxiliary one
    # just assists training and is expected to be the weaker of the two.
    return Model(image, [aux, main])