六、AlexNet实现中文字体识别——隶书和行楷
目录
- 前文
- 中文字体识别——隶书和行楷
- 数据生成器
- 图像显示
- AlexNet模型构建
- AlexNet模型编译与拟合
- 注意:
- GitHub下载地址:
前文
中文字体识别——隶书和行楷
数据生成器
from keras.preprocessing.image import ImageDataGenerator
IMSIZE = 227
validation_generator = ImageDataGenerator(rescale=1. / 255).flow_from_directory('../../data/ChineseStyle/test',
target_size=(IMSIZE, IMSIZE),
batch_size=200,
class_mode='categorical'
)
train_generator = ImageDataGenerator(rescale=1. / 255).flow_from_directory('../../data/ChineseStyle/train',
target_size=(IMSIZE, IMSIZE),
batch_size=200,
class_mode='categorical'
)
图像显示
from matplotlib import pyplot as plt
plt.figure()
fig, ax = plt.subplots(3, 5)
fig.set_figheight(7)
fig.set_figwidth(15)
ax = ax.flatten()
X, Y = next(validation_generator)
for i in range(15): ax[i].imshow(X[i, :, :, :, ])
AlexNet模型构建
from keras.layers import Activation, Conv2D, Dense
from keras.layers import Dropout, Flatten, Input, MaxPooling2D
from keras import Model
input_layer = Input([IMSIZE, IMSIZE, 3])
x = input_layer
x = Conv2D(96, [11, 11], strides=[4, 4], activation='relu')(x)
x = MaxPooling2D([3, 3], strides=[2, 2])(x)
x = Conv2D(256, [5, 5], padding="same", activation='relu')(x)
x = MaxPooling2D([3, 3], strides=[2, 2])(x)
x = Conv2D(384, [3, 3], padding="same", activation='relu')(x)
x = Conv2D(384, [3, 3], padding="same", activation='relu')(x)
x = Conv2D(256, [3, 3], padding="same", activation='relu')(x)
x = MaxPooling2D([3, 3], strides=[2, 2])(x)
x = Flatten()(x)
x = Dense(4096, activation='relu')(x)
x = Dropout(0.5)(x)
x = Dense(4096, activation='relu')(x)
x = Dropout(0.5)(x)
x = Dense(2, activation='softmax')(x)
output_layer = x
model = Model(input_layer, output_layer)
model.summary()
AlexNet模型编译与拟合
from keras.optimizers import Adam
model.compile(loss='categorical_crossentropy',
optimizer=Adam(lr=0.001), metrics=['accuracy'])
model.fit_generator(train_generator, epochs=20, validation_data=validation_generator)
注意:
因为自己是使用tensorflow-GPU版本,自己电脑是1050Ti,4G显存。实际运行时候batch_size设置为了不到15大小,太大了就显存资源不足。但是batch_size太小,总的数据集较大较多最后消耗时间就较长。
所以为了效率和烧显卡,请酌情考虑
GitHub下载地址:
Tensorflow1.15深度学习