"""
Created on 2020/11/29 19:42.
@Author: yubaby@anne
@Email: yhaif@foxmail.com
"""
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Dropout, BatchNormalization, Activation
from tensorflow.keras.layers import Conv2DTranspose, Add, concatenate
from tensorflow.keras import Model
def residual_block(input_x, input_filters, is_activate=False):
    """Apply a residual unit built from BN -> ReLU -> 3x3 conv stages.

    Two (BatchNormalization -> ReLU -> 3x3 'same' Conv2D) stages are applied
    to ``input_x`` and their result is summed with ``input_x`` through an
    identity shortcut (``Add``).

    Args:
        input_x: Input Keras tensor. Because the shortcut is an identity
            ``Add``, its channel count must already equal ``input_filters``.
        input_filters: Number of filters of both 3x3 convolutions.
        is_activate: If True, a BatchNormalization + ReLU is applied to the
            summed output as well.

    Returns:
        The output Keras tensor (same spatial size as ``input_x``, since
        every convolution uses stride 1 and 'same' padding).
    """
    branch = input_x
    for _ in range(2):
        branch = BatchNormalization()(branch)
        branch = Activation('relu')(branch)
        branch = Conv2D(filters=input_filters, kernel_size=3, strides=1, padding='same')(branch)

    merged = Add()([branch, input_x])
    if not is_activate:
        return merged

    # Optional post-addition normalisation/activation for the last unit of a stage.
    normalised = BatchNormalization()(merged)
    return Activation('relu')(normalised)
def _conv_stage(x, filters):
    """Apply one U-Net stage: a 3x3 conv to ``filters`` channels, then two residual blocks."""
    x = Conv2D(filters=filters, kernel_size=3, strides=1, padding='same')(x)
    x = residual_block(x, filters)
    # The second residual block also normalises/activates its summed output.
    return residual_block(x, filters, True)


def build_model(tif_size, bands, class_num, *, init_filters=64, use_dropout=True):
    """Build a residual U-Net for per-pixel classification.

    Architecture: four encoder stages (conv stage + 2x2 max-pool), a
    bottleneck conv stage, and four mirrored decoder stages (stride-2
    transposed conv, concatenation with the matching encoder output, conv
    stage), followed by a 1x1 softmax convolution named ``'outputs'``.
    A "conv stage" is a 3x3 conv followed by two residual blocks.

    Args:
        tif_size: Height and width of the square input tile. Must be a
            multiple of 16 (the input is down-sampled by 2 four times), or
            None for a variable-size input (the runtime size must then
            still be a multiple of 16).
        bands: Number of input channels.
        class_num: Number of output channels of the softmax layer.
        init_filters: Filters of the first stage; doubled at every level,
            reaching ``16 * init_filters`` in the bottleneck. Defaults to 64.
        use_dropout: Insert Dropout after each pooling (rate 0.25 for the
            first, 0.5 afterwards) and after each skip concatenation
            (rate 0.5). Defaults to True.

    Returns:
        An uncompiled ``tensorflow.keras.Model`` whose per-sample input shape
        is ``(tif_size, tif_size, bands)`` and whose output has
        ``class_num`` channels at the same spatial size.

    Raises:
        ValueError: If ``tif_size`` is not None and not a multiple of 16.
            Without this check, the decoder's transposed-conv outputs would
            not match the floored encoder sizes, and ``concatenate`` would
            fail with a less clear error.
    """
    from pathlib import Path

    depth = 4                      # number of pooling / up-sampling levels
    size_multiple = 2 ** depth     # each level halves the spatial size
    if tif_size is not None and tif_size % size_multiple != 0:
        raise ValueError(
            'tif_size must be a multiple of %d, got %r' % (size_multiple, tif_size))

    print('===== %s =====' % Path(__file__).name)
    print('===== %s =====' % build_model.__name__)

    encoder_dropout_rates = (0.25, 0.5, 0.5, 0.5)
    decoder_dropout_rate = 0.5

    # input
    inputs = Input(shape=(tif_size, tif_size, bands))

    # encoder: keep each stage's output (before pooling) for the skip connections
    x = inputs
    skips = []
    for level in range(depth):
        x = _conv_stage(x, init_filters * 2 ** level)
        skips.append(x)
        x = MaxPooling2D(pool_size=2)(x)
        if use_dropout:
            x = Dropout(rate=encoder_dropout_rates[level])(x)

    # middle
    x = _conv_stage(x, init_filters * 2 ** depth)

    # decoder: walk the levels back up, fusing the matching encoder output
    for level in reversed(range(depth)):
        filters = init_filters * 2 ** level
        x = Conv2DTranspose(filters=filters, kernel_size=3, strides=2, padding='same')(x)
        x = concatenate([x, skips[level]])
        if use_dropout:
            x = Dropout(rate=decoder_dropout_rate)(x)
        x = _conv_stage(x, filters)

    # output
    # NOTE(review): with class_num == 1, a softmax over a single channel is
    # constantly 1.0; binary segmentation would normally use a sigmoid here.
    outputs = Conv2D(class_num, (1, 1), activation='softmax', name='outputs')(x)
    return Model(inputs, outputs)