Pretrained Models for Transfer Learning
If you want to build an image classifier but don't have enough training data, reusing the lower layers of a pretrained model is usually a good idea. For example, let's train a model to classify pictures of flowers, reusing a pretrained Xception model. First, load the dataset with TensorFlow Datasets:
import tensorflow_datasets as tfds
(test_set, valid_set, train_set), info = tfds.load('tf_flowers', split=['train[:10%]', 'train[10%:25%]', 'train[25%:]'],
as_supervised=True, with_info=True)
dataset_size = info.splits['train'].num_examples
class_names = info.features['label'].names
n_classes = info.features['label'].num_classes
Setting with_info=True returns information about the dataset; here we get the dataset's size and the class names. There is only a "train" dataset, with no test or validation set, so we need to split the training set ourselves by passing a split argument to the load() function, as shown above: the first 10% for testing, the next 15% for validation, and the remaining 75% for training.
Next we must preprocess the images. The CNN expects \(224\times224\) images, so we need to resize them. We also need to run the images through Xception's preprocess_input() function:
import tensorflow as tf
from tensorflow import keras
def preprocess(image, label):
    resized_image = tf.image.resize(image, [224, 224])
    final_image = keras.applications.xception.preprocess_input(resized_image)
    return final_image, label
Apply this preprocessing function to all three datasets, shuffle the training set, and add batching and prefetching to all of them:
batch_size = 16
train_set = train_set.shuffle(1000)
train_set = train_set.map(preprocess).batch(batch_size).prefetch(1)
valid_set = valid_set.map(preprocess).batch(batch_size).prefetch(1)
test_set = test_set.map(preprocess).batch(batch_size).prefetch(1)
If you want to perform some data augmentation, you can change the training set's preprocessing function to add random transformations to the training images. For example, use tf.image.random_crop() to randomly crop the images and tf.image.random_flip_left_right() to randomly flip them horizontally.
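The exact augmentation code is not shown in the original; here is a minimal sketch of such an augmented preprocessing function (the 250 × 250 intermediate resize and the name preprocess_with_augmentation are illustrative assumptions):
def preprocess_with_augmentation(image, label):
    # Resize slightly larger than the target so the random crop has room to vary (250 is an assumption)
    resized_image = tf.image.resize(image, [250, 250])
    # Randomly crop a 224 x 224 patch, then randomly flip it horizontally
    cropped_image = tf.image.random_crop(resized_image, [224, 224, 3])
    flipped_image = tf.image.random_flip_left_right(cropped_image)
    final_image = keras.applications.xception.preprocess_input(flipped_image)
    return final_image, label
# Apply this function to the training set only; keep the plain preprocess() for the validation and test sets
# train_set = train_set.shuffle(1000).map(preprocess_with_augmentation).batch(batch_size).prefetch(1)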
Next, let's load an Xception model pretrained on ImageNet. We exclude the top of the network by setting include_top=False: this excludes the global average pooling layer and the dense output layer. We then add our own global average pooling layer on top of the base model's output, followed by a dense output layer with one unit per class using the softmax activation function. Finally, we create the Keras model:
base_model = keras.applications.xception.Xception(weights='imagenet',
include_top=False)
avg = keras.layers.GlobalAveragePooling2D()(base_model.output)
output = keras.layers.Dense(n_classes, activation='softmax')(avg)
model = keras.Model(inputs=base_model.input, outputs=output)
# It's usually a good idea to freeze the weights of the pretrained layers at the beginning of training, to avoid damaging them
for layer in base_model.layers:
layer.trainable = False
# Since our model uses the base model's layers directly, rather than the base_model object itself, setting base_model.trainable=False would have no effect
# Finally, compile the model and start training:
optimizer = keras.optimizers.SGD(learning_rate=0.2, momentum=0.9, decay=0.01)
model.compile(loss='sparse_categorical_crossentropy', optimizer=optimizer, metrics=['accuracy'])
history = model.fit(train_set, epochs=5, validation_data=valid_set)
Epoch 1/5
172/172 [==============================] - 11s 53ms/step - loss: 2.1127 - accuracy: 0.8045 - val_loss: 1.4148 - val_accuracy: 0.8457
Epoch 2/5
172/172 [==============================] - 8s 48ms/step - loss: 0.6569 - accuracy: 0.9033 - val_loss: 1.2112 - val_accuracy: 0.8711
Epoch 3/5
172/172 [==============================] - 8s 47ms/step - loss: 0.2501 - accuracy: 0.9495 - val_loss: 1.0145 - val_accuracy: 0.8784
Epoch 4/5
172/172 [==============================] - 8s 48ms/step - loss: 0.1455 - accuracy: 0.9637 - val_loss: 1.0198 - val_accuracy: 0.8711
Epoch 5/5
172/172 [==============================] - 8s 49ms/step - loss: 0.0916 - accuracy: 0.9738 - val_loss: 0.9966 - val_accuracy: 0.8693
After training the model for a few epochs, its validation accuracy reaches about 85–88% and then stops making much progress. This means the top layers are now reasonably well trained, so we are ready to unfreeze all the layers (or you can try unfreezing only the top ones) and continue training (don't forget to compile the model again whenever you freeze or unfreeze layers). This time we use a much lower learning rate to avoid damaging the pretrained weights:
for layer in base_model.layers:
layer.trainable = True
optimizer = keras.optimizers.SGD(learning_rate=0.01, momentum=0.9, decay=0.001)
model.compile(loss='sparse_categorical_crossentropy', optimizer=optimizer, metrics=['accuracy'])
history = model.fit(train_set, epochs=5, validation_data=valid_set)
Epoch 1/5
172/172 [==============================] - 38s 200ms/step - loss: 0.5472 - accuracy: 0.8227 - val_loss: 0.3796 - val_accuracy: 0.8693
Epoch 2/5
172/172 [==============================] - 34s 200ms/step - loss: 0.1651 - accuracy: 0.9455 - val_loss: 0.3555 - val_accuracy: 0.8875
Epoch 3/5
172/172 [==============================] - 35s 201ms/step - loss: 0.0772 - accuracy: 0.9778 - val_loss: 0.2701 - val_accuracy: 0.8984
Epoch 4/5
172/172 [==============================] - 35s 202ms/step - loss: 0.0444 - accuracy: 0.9873 - val_loss: 0.2782 - val_accuracy: 0.9056
Epoch 5/5
172/172 [==============================] - 35s 205ms/step - loss: 0.0274 - accuracy: 0.9920 - val_loss: 0.3030 - val_accuracy: 0.9111
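Earlier we mentioned that you could also try unfreezing only the top layers instead of all of them; here is a minimal sketch of that variant (the cutoff at the last 30 layers of the base model is an arbitrary illustrative choice, not from the original):
# Alternative: unfreeze only the top layers of the base model
for layer in base_model.layers[:-30]:
    layer.trainable = False
for layer in base_model.layers[-30:]:
    layer.trainable = True
# Recompile after changing the trainable flags, again with a low learning rate
model.compile(loss='sparse_categorical_crossentropy',
              optimizer=keras.optimizers.SGD(learning_rate=0.01, momentum=0.9),
              metrics=['accuracy'])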
model.summary()
Model: "model_2"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_4 (InputLayer) [(None, None, None, 0
__________________________________________________________________________________________________
block1_conv1 (Conv2D) (None, None, None, 3 864 input_4[0][0]
__________________________________________________________________________________________________
block1_conv1_bn (BatchNormaliza (None, None, None, 3 128 block1_conv1[0][0]
__________________________________________________________________________________________________
block1_conv1_act (Activation) (None, None, None, 3 0 block1_conv1_bn[0][0]
__________________________________________________________________________________________________
block1_conv2 (Conv2D) (None, None, None, 6 18432 block1_conv1_act[0][0]
__________________________________________________________________________________________________
block1_conv2_bn (BatchNormaliza (None, None, None, 6 256 block1_conv2[0][0]
__________________________________________________________________________________________________
block1_conv2_act (Activation) (None, None, None, 6 0 block1_conv2_bn[0][0]
__________________________________________________________________________________________________
block2_sepconv1 (SeparableConv2 (None, None, None, 1 8768 block1_conv2_act[0][0]
__________________________________________________________________________________________________
block2_sepconv1_bn (BatchNormal (None, None, None, 1 512 block2_sepconv1[0][0]
__________________________________________________________________________________________________
block2_sepconv2_act (Activation (None, None, None, 1 0 block2_sepconv1_bn[0][0]
__________________________________________________________________________________________________
block2_sepconv2 (SeparableConv2 (None, None, None, 1 17536 block2_sepconv2_act[0][0]
__________________________________________________________________________________________________
block2_sepconv2_bn (BatchNormal (None, None, None, 1 512 block2_sepconv2[0][0]
__________________________________________________________________________________________________
conv2d_12 (Conv2D) (None, None, None, 1 8192 block1_conv2_act[0][0]
__________________________________________________________________________________________________
block2_pool (MaxPooling2D) (None, None, None, 1 0 block2_sepconv2_bn[0][0]
__________________________________________________________________________________________________
batch_normalization_12 (BatchNo (None, None, None, 1 512 conv2d_12[0][0]
__________________________________________________________________________________________________
add_36 (Add) (None, None, None, 1 0 block2_pool[0][0]
batch_normalization_12[0][0]
__________________________________________________________________________________________________
block3_sepconv1_act (Activation (None, None, None, 1 0 add_36[0][0]
__________________________________________________________________________________________________
block3_sepconv1 (SeparableConv2 (None, None, None, 2 33920 block3_sepconv1_act[0][0]
__________________________________________________________________________________________________
block3_sepconv1_bn (BatchNormal (None, None, None, 2 1024 block3_sepconv1[0][0]
__________________________________________________________________________________________________
block3_sepconv2_act (Activation (None, None, None, 2 0 block3_sepconv1_bn[0][0]
__________________________________________________________________________________________________
block3_sepconv2 (SeparableConv2 (None, None, None, 2 67840 block3_sepconv2_act[0][0]
__________________________________________________________________________________________________
block3_sepconv2_bn (BatchNormal (None, None, None, 2 1024 block3_sepconv2[0][0]
__________________________________________________________________________________________________
conv2d_13 (Conv2D) (None, None, None, 2 32768 add_36[0][0]
__________________________________________________________________________________________________
block3_pool (MaxPooling2D) (None, None, None, 2 0 block3_sepconv2_bn[0][0]
__________________________________________________________________________________________________
batch_normalization_13 (BatchNo (None, None, None, 2 1024 conv2d_13[0][0]
__________________________________________________________________________________________________
add_37 (Add) (None, None, None, 2 0 block3_pool[0][0]
batch_normalization_13[0][0]
__________________________________________________________________________________________________
block4_sepconv1_act (Activation (None, None, None, 2 0 add_37[0][0]
__________________________________________________________________________________________________
block4_sepconv1 (SeparableConv2 (None, None, None, 7 188672 block4_sepconv1_act[0][0]
__________________________________________________________________________________________________
block4_sepconv1_bn (BatchNormal (None, None, None, 7 2912 block4_sepconv1[0][0]
__________________________________________________________________________________________________
block4_sepconv2_act (Activation (None, None, None, 7 0 block4_sepconv1_bn[0][0]
__________________________________________________________________________________________________
block4_sepconv2 (SeparableConv2 (None, None, None, 7 536536 block4_sepconv2_act[0][0]
__________________________________________________________________________________________________
block4_sepconv2_bn (BatchNormal (None, None, None, 7 2912 block4_sepconv2[0][0]
__________________________________________________________________________________________________
conv2d_14 (Conv2D) (None, None, None, 7 186368 add_37[0][0]
__________________________________________________________________________________________________
block4_pool (MaxPooling2D) (None, None, None, 7 0 block4_sepconv2_bn[0][0]
__________________________________________________________________________________________________
batch_normalization_14 (BatchNo (None, None, None, 7 2912 conv2d_14[0][0]
__________________________________________________________________________________________________
add_38 (Add) (None, None, None, 7 0 block4_pool[0][0]
batch_normalization_14[0][0]
__________________________________________________________________________________________________
block5_sepconv1_act (Activation (None, None, None, 7 0 add_38[0][0]
__________________________________________________________________________________________________
block5_sepconv1 (SeparableConv2 (None, None, None, 7 536536 block5_sepconv1_act[0][0]
__________________________________________________________________________________________________
block5_sepconv1_bn (BatchNormal (None, None, None, 7 2912 block5_sepconv1[0][0]
__________________________________________________________________________________________________
block5_sepconv2_act (Activation (None, None, None, 7 0 block5_sepconv1_bn[0][0]
__________________________________________________________________________________________________
block5_sepconv2 (SeparableConv2 (None, None, None, 7 536536 block5_sepconv2_act[0][0]
__________________________________________________________________________________________________
block5_sepconv2_bn (BatchNormal (None, None, None, 7 2912 block5_sepconv2[0][0]
__________________________________________________________________________________________________
block5_sepconv3_act (Activation (None, None, None, 7 0 block5_sepconv2_bn[0][0]
__________________________________________________________________________________________________
block5_sepconv3 (SeparableConv2 (None, None, None, 7 536536 block5_sepconv3_act[0][0]
__________________________________________________________________________________________________
block5_sepconv3_bn (BatchNormal (None, None, None, 7 2912 block5_sepconv3[0][0]
__________________________________________________________________________________________________
add_39 (Add) (None, None, None, 7 0 block5_sepconv3_bn[0][0]
add_38[0][0]
__________________________________________________________________________________________________
block6_sepconv1_act (Activation (None, None, None, 7 0 add_39[0][0]
__________________________________________________________________________________________________
block6_sepconv1 (SeparableConv2 (None, None, None, 7 536536 block6_sepconv1_act[0][0]
__________________________________________________________________________________________________
block6_sepconv1_bn (BatchNormal (None, None, None, 7 2912 block6_sepconv1[0][0]
__________________________________________________________________________________________________
block6_sepconv2_act (Activation (None, None, None, 7 0 block6_sepconv1_bn[0][0]
__________________________________________________________________________________________________
block6_sepconv2 (SeparableConv2 (None, None, None, 7 536536 block6_sepconv2_act[0][0]
__________________________________________________________________________________________________
block6_sepconv2_bn (BatchNormal (None, None, None, 7 2912 block6_sepconv2[0][0]
__________________________________________________________________________________________________
block6_sepconv3_act (Activation (None, None, None, 7 0 block6_sepconv2_bn[0][0]
__________________________________________________________________________________________________
block6_sepconv3 (SeparableConv2 (None, None, None, 7 536536 block6_sepconv3_act[0][0]
__________________________________________________________________________________________________
block6_sepconv3_bn (BatchNormal (None, None, None, 7 2912 block6_sepconv3[0][0]
__________________________________________________________________________________________________
add_40 (Add) (None, None, None, 7 0 block6_sepconv3_bn[0][0]
add_39[0][0]
__________________________________________________________________________________________________
block7_sepconv1_act (Activation (None, None, None, 7 0 add_40[0][0]
__________________________________________________________________________________________________
block7_sepconv1 (SeparableConv2 (None, None, None, 7 536536 block7_sepconv1_act[0][0]
__________________________________________________________________________________________________
block7_sepconv1_bn (BatchNormal (None, None, None, 7 2912 block7_sepconv1[0][0]
__________________________________________________________________________________________________
block7_sepconv2_act (Activation (None, None, None, 7 0 block7_sepconv1_bn[0][0]
__________________________________________________________________________________________________
block7_sepconv2 (SeparableConv2 (None, None, None, 7 536536 block7_sepconv2_act[0][0]
__________________________________________________________________________________________________
block7_sepconv2_bn (BatchNormal (None, None, None, 7 2912 block7_sepconv2[0][0]
__________________________________________________________________________________________________
block7_sepconv3_act (Activation (None, None, None, 7 0 block7_sepconv2_bn[0][0]
__________________________________________________________________________________________________
block7_sepconv3 (SeparableConv2 (None, None, None, 7 536536 block7_sepconv3_act[0][0]
__________________________________________________________________________________________________
block7_sepconv3_bn (BatchNormal (None, None, None, 7 2912 block7_sepconv3[0][0]
__________________________________________________________________________________________________
add_41 (Add) (None, None, None, 7 0 block7_sepconv3_bn[0][0]
add_40[0][0]
__________________________________________________________________________________________________
block8_sepconv1_act (Activation (None, None, None, 7 0 add_41[0][0]
__________________________________________________________________________________________________
block8_sepconv1 (SeparableConv2 (None, None, None, 7 536536 block8_sepconv1_act[0][0]
__________________________________________________________________________________________________
block8_sepconv1_bn (BatchNormal (None, None, None, 7 2912 block8_sepconv1[0][0]
__________________________________________________________________________________________________
block8_sepconv2_act (Activation (None, None, None, 7 0 block8_sepconv1_bn[0][0]
__________________________________________________________________________________________________
block8_sepconv2 (SeparableConv2 (None, None, None, 7 536536 block8_sepconv2_act[0][0]
__________________________________________________________________________________________________
block8_sepconv2_bn (BatchNormal (None, None, None, 7 2912 block8_sepconv2[0][0]
__________________________________________________________________________________________________
block8_sepconv3_act (Activation (None, None, None, 7 0 block8_sepconv2_bn[0][0]
__________________________________________________________________________________________________
block8_sepconv3 (SeparableConv2 (None, None, None, 7 536536 block8_sepconv3_act[0][0]
__________________________________________________________________________________________________
block8_sepconv3_bn (BatchNormal (None, None, None, 7 2912 block8_sepconv3[0][0]
__________________________________________________________________________________________________
add_42 (Add) (None, None, None, 7 0 block8_sepconv3_bn[0][0]
add_41[0][0]
__________________________________________________________________________________________________
block9_sepconv1_act (Activation (None, None, None, 7 0 add_42[0][0]
__________________________________________________________________________________________________
block9_sepconv1 (SeparableConv2 (None, None, None, 7 536536 block9_sepconv1_act[0][0]
__________________________________________________________________________________________________
block9_sepconv1_bn (BatchNormal (None, None, None, 7 2912 block9_sepconv1[0][0]
__________________________________________________________________________________________________
block9_sepconv2_act (Activation (None, None, None, 7 0 block9_sepconv1_bn[0][0]
__________________________________________________________________________________________________
block9_sepconv2 (SeparableConv2 (None, None, None, 7 536536 block9_sepconv2_act[0][0]
__________________________________________________________________________________________________
block9_sepconv2_bn (BatchNormal (None, None, None, 7 2912 block9_sepconv2[0][0]
__________________________________________________________________________________________________
block9_sepconv3_act (Activation (None, None, None, 7 0 block9_sepconv2_bn[0][0]
__________________________________________________________________________________________________
block9_sepconv3 (SeparableConv2 (None, None, None, 7 536536 block9_sepconv3_act[0][0]
__________________________________________________________________________________________________
block9_sepconv3_bn (BatchNormal (None, None, None, 7 2912 block9_sepconv3[0][0]
__________________________________________________________________________________________________
add_43 (Add) (None, None, None, 7 0 block9_sepconv3_bn[0][0]
add_42[0][0]
__________________________________________________________________________________________________
block10_sepconv1_act (Activatio (None, None, None, 7 0 add_43[0][0]
__________________________________________________________________________________________________
block10_sepconv1 (SeparableConv (None, None, None, 7 536536 block10_sepconv1_act[0][0]
__________________________________________________________________________________________________
block10_sepconv1_bn (BatchNorma (None, None, None, 7 2912 block10_sepconv1[0][0]
__________________________________________________________________________________________________
block10_sepconv2_act (Activatio (None, None, None, 7 0 block10_sepconv1_bn[0][0]
__________________________________________________________________________________________________
block10_sepconv2 (SeparableConv (None, None, None, 7 536536 block10_sepconv2_act[0][0]
__________________________________________________________________________________________________
block10_sepconv2_bn (BatchNorma (None, None, None, 7 2912 block10_sepconv2[0][0]
__________________________________________________________________________________________________
block10_sepconv3_act (Activatio (None, None, None, 7 0 block10_sepconv2_bn[0][0]
__________________________________________________________________________________________________
block10_sepconv3 (SeparableConv (None, None, None, 7 536536 block10_sepconv3_act[0][0]
__________________________________________________________________________________________________
block10_sepconv3_bn (BatchNorma (None, None, None, 7 2912 block10_sepconv3[0][0]
__________________________________________________________________________________________________
add_44 (Add) (None, None, None, 7 0 block10_sepconv3_bn[0][0]
add_43[0][0]
__________________________________________________________________________________________________
block11_sepconv1_act (Activatio (None, None, None, 7 0 add_44[0][0]
__________________________________________________________________________________________________
block11_sepconv1 (SeparableConv (None, None, None, 7 536536 block11_sepconv1_act[0][0]
__________________________________________________________________________________________________
block11_sepconv1_bn (BatchNorma (None, None, None, 7 2912 block11_sepconv1[0][0]
__________________________________________________________________________________________________
block11_sepconv2_act (Activatio (None, None, None, 7 0 block11_sepconv1_bn[0][0]
__________________________________________________________________________________________________
block11_sepconv2 (SeparableConv (None, None, None, 7 536536 block11_sepconv2_act[0][0]
__________________________________________________________________________________________________
block11_sepconv2_bn (BatchNorma (None, None, None, 7 2912 block11_sepconv2[0][0]
__________________________________________________________________________________________________
block11_sepconv3_act (Activatio (None, None, None, 7 0 block11_sepconv2_bn[0][0]
__________________________________________________________________________________________________
block11_sepconv3 (SeparableConv (None, None, None, 7 536536 block11_sepconv3_act[0][0]
__________________________________________________________________________________________________
block11_sepconv3_bn (BatchNorma (None, None, None, 7 2912 block11_sepconv3[0][0]
__________________________________________________________________________________________________
add_45 (Add) (None, None, None, 7 0 block11_sepconv3_bn[0][0]
add_44[0][0]
__________________________________________________________________________________________________
block12_sepconv1_act (Activatio (None, None, None, 7 0 add_45[0][0]
__________________________________________________________________________________________________
block12_sepconv1 (SeparableConv (None, None, None, 7 536536 block12_sepconv1_act[0][0]
__________________________________________________________________________________________________
block12_sepconv1_bn (BatchNorma (None, None, None, 7 2912 block12_sepconv1[0][0]
__________________________________________________________________________________________________
block12_sepconv2_act (Activatio (None, None, None, 7 0 block12_sepconv1_bn[0][0]
__________________________________________________________________________________________________
block12_sepconv2 (SeparableConv (None, None, None, 7 536536 block12_sepconv2_act[0][0]
__________________________________________________________________________________________________
block12_sepconv2_bn (BatchNorma (None, None, None, 7 2912 block12_sepconv2[0][0]
__________________________________________________________________________________________________
block12_sepconv3_act (Activatio (None, None, None, 7 0 block12_sepconv2_bn[0][0]
__________________________________________________________________________________________________
block12_sepconv3 (SeparableConv (None, None, None, 7 536536 block12_sepconv3_act[0][0]
__________________________________________________________________________________________________
block12_sepconv3_bn (BatchNorma (None, None, None, 7 2912 block12_sepconv3[0][0]
__________________________________________________________________________________________________
add_46 (Add) (None, None, None, 7 0 block12_sepconv3_bn[0][0]
add_45[0][0]
__________________________________________________________________________________________________
block13_sepconv1_act (Activatio (None, None, None, 7 0 add_46[0][0]
__________________________________________________________________________________________________
block13_sepconv1 (SeparableConv (None, None, None, 7 536536 block13_sepconv1_act[0][0]
__________________________________________________________________________________________________
block13_sepconv1_bn (BatchNorma (None, None, None, 7 2912 block13_sepconv1[0][0]
__________________________________________________________________________________________________
block13_sepconv2_act (Activatio (None, None, None, 7 0 block13_sepconv1_bn[0][0]
__________________________________________________________________________________________________
block13_sepconv2 (SeparableConv (None, None, None, 1 752024 block13_sepconv2_act[0][0]
__________________________________________________________________________________________________
block13_sepconv2_bn (BatchNorma (None, None, None, 1 4096 block13_sepconv2[0][0]
__________________________________________________________________________________________________
conv2d_15 (Conv2D) (None, None, None, 1 745472 add_46[0][0]
__________________________________________________________________________________________________
block13_pool (MaxPooling2D) (None, None, None, 1 0 block13_sepconv2_bn[0][0]
__________________________________________________________________________________________________
batch_normalization_15 (BatchNo (None, None, None, 1 4096 conv2d_15[0][0]
__________________________________________________________________________________________________
add_47 (Add) (None, None, None, 1 0 block13_pool[0][0]
batch_normalization_15[0][0]
__________________________________________________________________________________________________
block14_sepconv1 (SeparableConv (None, None, None, 1 1582080 add_47[0][0]
__________________________________________________________________________________________________
block14_sepconv1_bn (BatchNorma (None, None, None, 1 6144 block14_sepconv1[0][0]
__________________________________________________________________________________________________
block14_sepconv1_act (Activatio (None, None, None, 1 0 block14_sepconv1_bn[0][0]
__________________________________________________________________________________________________
block14_sepconv2 (SeparableConv (None, None, None, 2 3159552 block14_sepconv1_act[0][0]
__________________________________________________________________________________________________
block14_sepconv2_bn (BatchNorma (None, None, None, 2 8192 block14_sepconv2[0][0]
__________________________________________________________________________________________________
block14_sepconv2_act (Activatio (None, None, None, 2 0 block14_sepconv2_bn[0][0]
__________________________________________________________________________________________________
global_average_pooling2d_3 (Glo (None, 2048) 0 block14_sepconv2_act[0][0]
__________________________________________________________________________________________________
dense_3 (Dense) (None, 5) 10245 global_average_pooling2d_3[0][0]
==================================================================================================
Total params: 20,871,725
Trainable params: 20,817,197
Non-trainable params: 54,528
__________________________________________________________________________________________________
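With fine-tuning done, a natural last step (not shown in the original) is to measure performance on the test set that was split off at the beginning; a minimal sketch, assuming the test_set pipeline defined above:
# Evaluate the fine-tuned model on the held-out test set built earlier
test_loss, test_accuracy = model.evaluate(test_set)
print(f"Test accuracy: {test_accuracy:.3f}")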