GANs: Generative Adversarial Networks
Crypko uses a GAN to generate 2D anime waifus. So, out of curiosity (definitely not because I want an anime waifu), can't we train a GAN model ourselves and create some of our own?
The code below is based on PyTorch, and the model definitions borrow from DCGAN. If you are a TensorFlow user and unfamiliar with PyTorch, you can instead read the source of one of the popular TensorFlow implementations of DCGAN.
1 Dataset and Model Definitions
1.1 Dataset Definition
The dataset comes from [1]: https://pan.baidu.com/s/1gJ00Uipghq991Piq-j29dQ extraction code: 2hls
import os
import PIL.Image as Image
from torch.utils.data import Dataset


class AnimeDataset(Dataset):
    def __init__(self, data_dir, trans=None, filter_list=None):
        """
        Args:
            data_dir: directory containing the images
            trans: transforms applied when an image is loaded
            filter_list: accepted file name extensions
        """
        super(AnimeDataset, self).__init__()
        self.data_dir = data_dir
        self.trans = trans
        if filter_list is None:
            filter_list = ['.jpg', '.jpeg', '.webp', '.bmp']
        self.img_names = [name for name in os.listdir(self.data_dir)
                          if name.endswith(tuple(filter_list))]

    def __getitem__(self, index):
        path_img = os.path.join(self.data_dir, self.img_names[index])
        img = Image.open(path_img).convert('RGB')
        if self.trans is not None:
            img = self.trans(img)
        return img

    def __len__(self):
        n = len(self.img_names)
        if n == 0:
            raise Exception('No images found under this path, please check it')
        return n
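As a quick sanity check, the dataset can be fed straight to a DataLoader. A minimal sketch (the path below is a placeholder for wherever you unpacked the download):

import torchvision.transforms as transforms
from torch.utils.data import DataLoader

trans = transforms.Compose([transforms.Resize(96),
                            transforms.CenterCrop(96),
                            transforms.ToTensor()])
dataset = AnimeDataset('../data/anime/face', trans=trans)  # placeholder path
loader = DataLoader(dataset, batch_size=4, shuffle=True)
imgs = next(iter(loader))
print(imgs.shape)  # torch.Size([4, 3, 96, 96])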
1.2 Generator
The generator uses transposed convolution (also called deconvolution or fractionally-strided convolution), so here is a brief explanation.
Ordinary convolution is the more familiar operation, with the well-known output-size formula:
\[w' = \left\lfloor\frac{w + 2 \times padding - kernel\_size}{stride}\right\rfloor + 1\\ h' = \left\lfloor\frac{h + 2 \times padding - kernel\_size}{stride}\right\rfloor + 1\]
\(w'\) and \(h'\) are the width and height of the image after the convolution.
Transposed convolution is simply the reverse process: given the post-convolution \(w'\) and \(h'\), compute the pre-convolution \(w\) and \(h\). The formula is just the one above rearranged (dropping the floor):
\[w = (w' - 1) \times stride - 2 \times padding + kernel\_size\\ h = (h' - 1) \times stride - 2 \times padding + kernel\_size\]
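As a quick check of the formula: with kernel_size=4, stride=2, padding=1, a transposed convolution exactly doubles the spatial size, which is why the generator below uses those values for every upsampling block. A minimal sketch (the channel sizes here are arbitrary):

import torch
import torch.nn as nn

# (w' - 1) * stride - 2 * padding + kernel_size = (6 - 1) * 2 - 2 * 1 + 4 = 12
deconv = nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1)
x = torch.randn(1, 512, 6, 6)
print(deconv(x).shape)  # torch.Size([1, 256, 12, 12])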
import torch.nn as nn
from my_utils.tools import init_weights


class Generator(nn.Module):
    def __init__(self, nzd=100, ngf=64, c=3):
        """
        Args:
            nzd: channel dim of the noise vector
            ngf: number of generator feature maps
            c: number of output channels
        """
        super(Generator, self).__init__()
        self.net = nn.Sequential(
            # (nzd, 1, 1) -> (ngf * 8, 6, 6)
            *self.create_block(nzd, ngf * 8, 6, 1, 0),
            # (ngf * 8, 6, 6) -> (ngf * 4, 12, 12)
            *self.create_block(ngf * 8, ngf * 4),
            # (ngf * 4, 12, 12) -> (ngf * 2, 24, 24)
            *self.create_block(ngf * 4, ngf * 2),
            # (ngf * 2, 24, 24) -> (ngf, 48, 48)
            *self.create_block(ngf * 2, ngf),
            # (ngf, 48, 48) -> (c, 96, 96)
            *self.create_block(ngf, c, last=True),
        )

    def create_block(self, in_channel, out_channel, kernel_size=4, stride=2, padding=1, last=False):
        layer_list = [nn.ConvTranspose2d(in_channel, out_channel, kernel_size, stride, padding, bias=False)]
        if last:
            layer_list.append(nn.Tanh())
        else:
            layer_list.extend([nn.BatchNorm2d(out_channel),
                               nn.ReLU(inplace=True)])
        return layer_list

    def forward(self, X):
        return self.net(X)

    def initialize_weights(self, w_mean=0., w_std=0.02, b_mean=1., b_std=0.02):
        init_weights(self.modules(), w_mean, w_std, b_mean, b_std)
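A quick shape check confirms that the stack maps a noise vector to a 96*96 RGB image (a sketch using the default nzd=100):

import torch

g = Generator()  # defaults: nzd=100, ngf=64, c=3
z = torch.randn(2, 100, 1, 1)
print(g(z).shape)  # torch.Size([2, 3, 96, 96])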
1.3 Discriminator
class Discriminator(nn.Module):
    def __init__(self, ndf=64, c=3):
        super(Discriminator, self).__init__()
        self.net = nn.Sequential(
            # (c, 96, 96) -> (ndf, 48, 48)
            *self.create_block(c, ndf),
            # (ndf, 48, 48) -> (ndf * 2, 24, 24)
            *self.create_block(ndf, ndf * 2),
            # (ndf * 2, 24, 24) -> (ndf * 4, 12, 12)
            *self.create_block(ndf * 2, ndf * 4),
            # (ndf * 4, 12, 12) -> (ndf * 8, 6, 6)
            *self.create_block(ndf * 4, ndf * 8),
            # (ndf * 8, 6, 6) -> (ndf * 16, 3, 3)
            *self.create_block(ndf * 8, ndf * 16),
            # (ndf * 16, 3, 3) -> (1, 1, 1)
            *self.create_block(ndf * 16, 1, 3, 1, 0, last=True)
        )

    def forward(self, X):
        return self.net(X)

    def create_block(self, in_channel, out_channel, kernel_size=4, stride=2, padding=1, last=False):
        layer_list = [nn.Conv2d(in_channel, out_channel, kernel_size, stride, padding, bias=False)]
        if last:
            layer_list.append(nn.Sigmoid())
        else:
            layer_list.extend([nn.BatchNorm2d(out_channel),
                               nn.LeakyReLU(0.2, inplace=True)])
        return layer_list

    def initialize_weights(self, w_mean=0., w_std=0.02, b_mean=1., b_std=0.02):
        init_weights(self.modules(), w_mean, w_std, b_mean, b_std)
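The same kind of check works for the discriminator: feeding it generator-sized images should yield one score per image:

import torch

d = Discriminator()  # defaults: ndf=64, c=3
imgs = torch.randn(2, 3, 96, 96)
print(d(imgs).view(-1).shape)  # torch.Size([2]), one score in (0, 1) per image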
1.4 Shared Utility Functions
import os
import imageio
import torchvision.transforms as transforms
from torch import nn
from torch.utils.data import DataLoader
from my_utils.animedataset import AnimeDataset


def init_weights(modules, w_mean=0., w_std=0.02, b_mean=1., b_std=0.02):
    """
    Initialize the model parameters (DCGAN-style).
    """
    for m in modules:
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            nn.init.normal_(m.weight.data, w_mean, w_std)
        elif classname.find('BatchNorm') != -1:
            nn.init.normal_(m.weight.data, b_mean, b_std)
            nn.init.constant_(m.bias.data, 0)


def load_data(filepath, batch_size, trans=None):
    """
    Load the dataset.
    """
    train_set = AnimeDataset(filepath, trans)
    return DataLoader(train_set, batch_size=batch_size, num_workers=1, shuffle=True)


def init_trans(image_size):
    """
    Transforms applied to each image.
    """
    return transforms.Compose([transforms.Resize(image_size),
                               transforms.CenterCrop(image_size),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                               ])


def gen_gif(src, suffix, dst, filename):
    """
    Read the images ending with `suffix` from `src`, assemble them into an
    animated GIF, and save it to `dst` under the name `filename`.

    Args:
        src: source directory
        suffix: file name suffix
        dst: directory the GIF is saved to
        filename: file name of the GIF
    """
    # The per-epoch test images are saved as '{epoch}_epoch.png', so the part
    # before the '_' is a number; adjust this to your own naming scheme.
    imgs_epoch = [int(name.split('_')[0]) for name in
                  list(filter(lambda x: x.endswith(suffix), os.listdir(src)))]
    imgs_epoch = sorted(imgs_epoch)
    imgs = list()
    for i in range(len(imgs_epoch)):
        img_name = os.path.join(src, f'{imgs_epoch[i]}{suffix}')
        imgs.append(imageio.imread(img_name))
    # fps follows the imageio v2 API; newer imageio versions may expect
    # `duration` instead.
    imageio.mimsave(os.path.join(dst, filename), imgs, fps=2)
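One detail worth noting: Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) maps pixels from [0, 1] into [-1, 1], which matches the Tanh output of the generator. To display a generated image you have to undo it. A minimal sketch (the helper name denormalize is my own; the training script below instead relies on make_grid(..., normalize=True)):

def denormalize(img):
    # Undo Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)): map [-1, 1] back to [0, 1]
    return img * 0.5 + 0.5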
2 Training
import os
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.utils as vutils
from visdom import Visdom
from my_utils.model import Discriminator, Generator
from my_utils.tools import load_data, init_trans, gen_gif

# Define a few variables for later use
device = 'cuda' if torch.cuda.is_available() else 'cpu'
dirs = {
    'img': './img',
    'img_data': '../data/anime/face',
    'model': './model'
}
# Image size, checkpoint interval (save the model every 10 epochs),
# loss-recording interval (record the losses every 10 iterations)
image_size, checkpoint_interval, record_loss_interval = 96, 10, 10
# Labels for real/fake images, softened (label smoothing) instead of hard 1/0
real_img_label, fake_img_label = 0.9, 0.1
# Hyperparameters (with more GPU memory, batch_size, nzd, ngf can be increased)
epochs, lr, batch_size, beta1, nzd, ngf, ndf, channel = 20, 1e-4, 10, 0.5, 50, 64, 64, 3
# fixed_noise is used to monitor the generator's training progress
fixed_noise = torch.randn(64, nzd, 1, 1, device=device)
viz = Visdom()
if __name__ == '__main__':
    # 1 Load the data
    train_iter = load_data(dirs['img_data'], batch_size, init_trans(image_size))
    # 2 Create the networks, loss function and optimizers
    loss = nn.BCELoss()
    g_net = Generator(nzd, ngf, channel).to(device)
    g_optimizer = optim.Adam(g_net.parameters(), lr, betas=(beta1, 0.999))
    g_lr_scheduler = optim.lr_scheduler.StepLR(g_optimizer, step_size=8, gamma=0.1)
    d_net = Discriminator(ndf, channel).to(device)
    d_optimizer = optim.Adam(d_net.parameters(), lr, betas=(beta1, 0.999))
    d_lr_scheduler = optim.lr_scheduler.StepLR(d_optimizer, step_size=8, gamma=0.1)
    # 3 Start training
    for epoch in range(1, epochs + 1):
        Len = len(train_iter)  # number of iterations per epoch
        record_times = Len // record_loss_interval  # how many times the losses are recorded
        # Four columns, holding: g_loss, D(x), D(G(x)), d_loss
        loss_record = torch.zeros(size=(record_times, 4))
        cnt = 0
        print('#' * 30 + f'\n epoch={epoch} start \n' + '#' * 30)
        for i, data in enumerate(train_iter):
            data = data.to(device)
            ############################
            # (1) Update D network
            ###########################
            d_net.zero_grad()
            b_size = len(data)
            real_label = torch.full((b_size,), real_img_label, device=device)
            fake_label = torch.full((b_size,), fake_img_label, device=device)
            noise = torch.randn((b_size, nzd, 1, 1), device=device)
            # Generate a batch of fake images from noise; fake_img.shape: (b_size, c=3, h=96, w=96)
            fake_img = g_net(noise)
            a = d_net(data)
            b = d_net(fake_img.detach())
            loss_d_real = loss(a.view(-1), real_label)
            loss_d_fake = loss(b.view(-1), fake_label)
            loss_d_real.backward()
            loss_d_fake.backward()
            d_optimizer.step()
            loss_d = loss_d_real + loss_d_fake  # discriminator loss on real + fake images
            d_x = a.mean().item()  # D(x): mean discriminator score on real images
            d_g_x = b.mean().item()  # D(G(x)): mean discriminator score on fake images
            ############################
            # (2) Update G network
            ###########################
            g_net.zero_grad()
            # The discriminator has just been updated; score the fake images
            # fake_img again and use that to update the generator
            out_d_fake = d_net(fake_img)  # shape: (b_size, 1, 1, 1)
            loss_g = loss(out_d_fake.view(-1), real_label)
            loss_g.backward()
            g_optimizer.step()
            d_g_x2 = out_d_fake.mean().item()  # D(G(x2))
            if i % record_loss_interval == 0 and i != 0:
                print(f'[{epoch}/{epochs}]\t[{i}/{Len}]\t'
                      f'Loss_g={loss_g.item():.4f}\tD(x)={d_x:.4f}\tD(G(x))={d_g_x:.4f}/{d_g_x2:.4f}')
                # Record the losses
                loss_record[cnt] = torch.tensor([loss_g.item(), d_x, d_g_x, loss_d.item()])
                cnt += 1
        # Decay the learning rates at the end of each epoch
        d_lr_scheduler.step()
        g_lr_scheduler.step()
        # After each epoch, generate fake images from fixed_noise with g_net to monitor progress
        with torch.no_grad():
            fake = g_net(fixed_noise).detach().cpu()
        img_grid = vutils.make_grid(fake, padding=2, normalize=True).numpy()
        img_grid = np.transpose(img_grid, (1, 2, 0))
        plt.imsave(os.path.join(dirs['img'], f'{epoch}_epoch.png'), img_grid)
        # Plot this epoch's mean losses with visdom
        viz.line(Y=loss_record[:, [0, -1]].mean(keepdim=True, dim=0),
                 X=torch.full(size=(1, 2), fill_value=epoch),
                 win='g loss & d loss',
                 update='append',
                 opts=dict(title='Mean: g loss & d loss', legend=['g', 'd']))
        viz.line(Y=loss_record[:, [1, 2]].mean(keepdim=True, dim=0),
                 X=torch.full(size=(1, 2), fill_value=epoch),
                 win='D(x) & D(G(x))',
                 update='append',
                 opts=dict(title='Mean: D(x) & D(G(x))', legend=['D(x)', 'D(G(x))']))
        # Save a checkpoint
        if epoch % checkpoint_interval == 0:
            checkpoint = {"g_model_state_dict": g_net.state_dict(),
                          "d_model_state_dict": d_net.state_dict(),
                          "epoch": epoch}
            path_checkpoint = os.path.join(dirs['model'], "checkpoint_{}_epoch.pkl".format(epoch))
            torch.save(checkpoint, path_checkpoint)
    # suffix is the file name suffix of the per-epoch test images generated
    # above, which end with '_epoch.png'
    gen_gif(src=dirs['img'], suffix='_epoch.png', dst=dirs['img'], filename='progress.gif')
    print("done")
3 Results
The animation above was produced from 3,000 96*96 face images trained for 20 epochs. The quality clearly improves over time, but it is still far from good, mainly because of GPU constraints:
- the original dataset has 50k 96*96 images, of which only 3,000 were used
- batch_size=10, nzd=50, ngf=64: with limited GPU memory, this was the best that fit
- only 20 epochs were run; training longer would likely help
4 References
[1] GAN学习指南:从原理入门到制作生成Demo (GAN study guide: from principles to a working generative demo)
[2] DCGAN原理分析与代码解读 (DCGAN: principle analysis and code walkthrough)
[3] Radford, A., et al. (2015) Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks. arXiv:1511.06434