"""
Created on 2021/1/26 22:01.
@Author: anne
"""
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Dropout
from tensorflow.keras.layers import UpSampling2D, concatenate, Conv2DTranspose
from tensorflow.keras import Model
def _double_conv(x, filters, dropout_rate=None):
    """Apply two 3x3, stride-1, 'same'-padded ReLU convolutions to ``x``.

    Args:
        x: Keras tensor to transform.
        filters: number of output channels of each convolution.
        dropout_rate: if not None, a ``Dropout(dropout_rate)`` layer follows
            each convolution (used in the bottleneck).

    Returns:
        The resulting Keras tensor (same spatial size as ``x``).
    """
    for _ in range(2):
        x = Conv2D(filters, (3, 3), activation="relu", padding="same")(x)
        if dropout_rate is not None:
            x = Dropout(dropout_rate)(x)
    return x


def build_model(tif_size, bands, class_num, init_filters=64, dropout_rate=0.5):
    """Build a U-Net for per-pixel classification of square multi-band tiles.

    Architecture:
      * encoder: four stages of two 3x3 ReLU convolutions followed by 2x2 max
        pooling; the filter count starts at ``init_filters`` and doubles at
        every stage;
      * bottleneck: two 3x3 ReLU convolutions with ``16 * init_filters``
        filters, each followed by dropout;
      * decoder: four stages of a stride-2 3x3 transposed convolution, a
        concatenation with the matching encoder output (skip connection),
        and two 3x3 ReLU convolutions;
      * head: a 1x1 convolution with softmax over ``class_num`` channels.

    Args:
        tif_size: height and width of the input tile, or None for a variable
            spatial size. Must be a positive multiple of 16 (2 ** 4, one
            factor of 2 per pooling stage) so every upsampled feature map
            matches its skip connection. With None, inputs fed at run time
            must still satisfy this.
        bands: number of input channels; must be >= 1.
        class_num: number of output classes; must be >= 1.
        init_filters: filters in the first encoder stage. The default of 64
            is the filter count used by the original U-Net.
        dropout_rate: rate of the two bottleneck Dropout layers. If None, the
            Dropout layers are omitted.

    Returns:
        An uncompiled ``Model`` mapping ``(batch, tif_size, tif_size, bands)``
        to ``(batch, tif_size, tif_size, class_num)`` softmax probabilities.

    Raises:
        ValueError: if ``tif_size`` is not None and not a positive multiple of
            16, or if ``bands`` or ``class_num`` is missing or < 1.
    """
    from pathlib import Path
    # Debug banners: the defining file's name and this function's name.
    print('===== %s =====' % Path(__file__).name)
    print('===== %s =====' % build_model.__name__)

    depth = 4  # number of 2x2 max-pooling stages in the encoder
    factor = 2 ** depth
    # Fail fast with a clear message instead of an opaque shape mismatch
    # inside `concatenate` when a pooled size was rounded down.
    if tif_size is not None and (tif_size <= 0 or tif_size % factor != 0):
        raise ValueError(
            'tif_size must be None or a positive multiple of %d, got %r'
            % (factor, tif_size))
    if bands is None or bands < 1:
        raise ValueError('bands must be a positive integer, got %r' % (bands,))
    if class_num is None or class_num < 1:
        raise ValueError(
            'class_num must be a positive integer, got %r' % (class_num,))

    # 1 input
    inputs = Input(shape=(tif_size, tif_size, bands))

    # 2 encoder: keep each stage's pre-pooling output for the skip connections.
    skips = []
    x = inputs
    for level in range(depth):
        x = _double_conv(x, init_filters * 2 ** level)
        skips.append(x)
        x = MaxPooling2D(pool_size=(2, 2))(x)

    # 3 bottleneck
    x = _double_conv(x, init_filters * 2 ** depth, dropout_rate=dropout_rate)

    # 4 decoder built from transposed convolutions (Conv2DTranspose).
    # Every Conv2D uses the default stride 1. The transposed convolution uses
    # stride 2, which doubles the spatial size like UpSampling2D(size=(2, 2)).
    for level in reversed(range(depth)):
        filters = init_filters * 2 ** level
        up = Conv2DTranspose(filters, (3, 3), strides=2, activation="relu",
                             padding="same")(x)
        x = concatenate([skips[level], up], axis=3)
        x = _double_conv(x, filters)

    # 5 output: per-pixel class probabilities
    outputs = Conv2D(class_num, (1, 1), activation='softmax', name='outputs')(x)
    return Model(inputs, outputs)