七、Data Augmentation技巧


目录
  • 前文
  • 数据生成器+数据部分展示
  • 数据增强模型
  • 数据增强模型的编译与拟合
  • GitHub下载地址:

前文

数据生成器+数据部分展示

#猫狗分类。数据增强
#数据生成器生成测试集
from keras.preprocessing.image import ImageDataGenerator

IMSIZE = 128
validation_generator = ImageDataGenerator(rescale=1. / 255).flow_from_directory(
    '../../data/dogs-vs-cats/smallData/validation',
    target_size=(IMSIZE, IMSIZE),
    batch_size=10,
    class_mode='categorical'
)

在这里插入图片描述

# 利用数据增强技术生成的训练集
train_generator = ImageDataGenerator(rescale=1. / 255, shear_range=0.5, rotation_range=30,
                                     zoom_range=0.2, width_shift_range=0.2, height_shift_range=0.2
                                     ).flow_from_directory('../../data/dogs-vs-cats/smallData/train',
                                                           target_size=(IMSIZE, IMSIZE), batch_size=10,
                                                           class_mode='categorical')

在这里插入图片描述

数据来源kaggle的猫狗数据

#展示数据增强后的图像
from matplotlib import pyplot as plt

plt.figure()
fig, ax = plt.subplots(2, 5)
fig.set_figheight(6)
fig.set_figwidth(15)
ax = ax.flatten()
X, Y = next(validation_generator)
for i in range(10): ax[i].imshow(X[i, :, :, ])

在这里插入图片描述

数据增强模型

#数据增强模型
IMSIZE = 128
from keras.layers import BatchNormalization, Conv2D, Dense, Flatten, Input, MaxPooling2D
from keras import Model

n_channel = 100
input_layer = Input([IMSIZE, IMSIZE, 3])
x = input_layer
x =BatchNormalization()(x)
for _ in range(7):
    x =BatchNormalization()(x)
    x =Conv2D(n_channel,[2,2],padding='same',activation='relu')(x)
    x =MaxPooling2D([2,2])(x)

x =Flatten()(x)
x =Dense(2,activation='softmax')(x)
output_layer = x
model = Model(input_layer,output_layer)
model.summary()

在这里插入图片描述

数据增强模型的编译与拟合

#数据增强模型的编译与拟合
from keras.optimizers import Adam
model.compile(loss='categorical_crossentropy',
               optimizer=Adam(lr=0.0001),
               metrics=['accuracy'])
model.fit_generator(train_generator,
                     epochs=200,
                     validation_data=validation_generator)

在这里插入图片描述

GitHub下载地址:

Tensorflow1.15深度学习