模型处理-08


  模型是神经网络训练优化后得到的成果, 包含了神经网络骨架及学习得到的参数。 PyTorch对于模型的处理提供了丰富的工具, 本节将从模型的生成、 预训练模型的加载和模型保存3个方面进行介绍。

1. 网络模型库: torchvision.models
对于深度学习, torchvision.models库提供了众多经典的网络结构与预训练模型, 例如VGGResNetInception等, 利用这些模型可以快速搭建物体检测网络, 不需要逐层手动实现。 torchvision包与PyTorch相独立, 需要通过pip指令进行安装, 如下:

1 pip install torchvision # 适用于Python 2
2 pip3 install torchvision # 适用于Python 3 

VGG模型为例, 在torchvision.models中, VGG模型的特征层与分类层分别用vgg.featuresvgg.classifier来表示, 每个部分是一个nn.Sequential结构, 可以方便地使用与修改。 下面讲解如何使用torchvision.model模块。

 1 from torch import nn
 2 from torchvision import models
 3 
 4 # 通过torchvision.model直接调用VGG16的网络结构
 5 vgg = models.vgg16()
 6 
 7 # VGG16的特征层包括13个卷积、 13个激活函数ReLU、 5个池化, 一共31层
 8 print(len(vgg.features))
 9 >> 31
10 
11 # VGG16的分类层包括3个全连接、 2个ReLU、 2个Dropout, 一共7层
12 print(len(vgg.classifier))
13 >> 7
14 
15 # 可以通过出现的顺序直接索引每一层
16 print(vgg.classifier[-1])
17 >> Linear(in_features=4096, out_features=1000, bias=True)
18 
19 # 也可以选取某一部分, 如下代表了特征网络的最后一个卷积模组
20 print(vgg.features[24:])
21 >> Sequential(
22     (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
23     (25): ReLU(inplace)
24     (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
25     (27): ReLU(inplace)
26     (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
27     (29): ReLU(inplace)
28     (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
29   )

2. 加载预训练模型
对于计算机视觉的任务, 包括物体检测, 我们通常很难拿到很大的数据集, 在这种情况下重新训练一个新的模型是比较复杂的, 并且不容易调整, 因此, Fine-tune(微调) 是一个常用的选择。 所谓Fine-tune是指利用别人在一些数据集上训练好的预训练模型, 在自己的数据集上训练自己的模型。

在具体使用时, 通常有两种情况, 第一种是直接利用torchvision.models中自带的预训练模型, 只需要在使用时赋予pretrained参数为True即可。

1 from torch import nn
2 from torchvision import models
3 
4 # 通过torchvision.model直接调用VGG16的网络结构
5 vgg = models.vgg16(pretrained=True)

第二种是如果想要使用自己的本地预训练模型, 或者之前训练过的模型, 则可以通过model.load_state_dict()函数操作, 具体如下:

 1 import torch
 2 from torch import nn
 3 from torchvision import models
 4 
 5 # 通过torchvision.model直接调用VGG16的网络结构
 6 vgg = models.vgg16()
 7 static_dict = torch.load(" your model path")
 8 
 9 # 利用load_state_dict, 遍历预训练模型的关键字, 如果出现在了VGG中, 则加载预训练参数
10 vgg.load_state_dict({k:v for k,v in state_dict_items() if k in vgg.state_dict()})

通常来讲, 对于不同的检测任务, 卷积网络的前两三层的作用是非常类似的, 都是提取图像的边缘信息等, 因此为了保证模型训练中能够更加稳定, 一般会固定预训练网络的前两三个卷积层而不进行参数的学习。 例如VGG模型, 可以设置前三个卷积模组不进行参数学习, 设置方式如下:

1 for layer in range(10):
2    for p in vgg[layer].parameters():
3       p.requires_grad = False

3. 模型保存

PyTorch中, 参数的保存通过torch.save()函数实现, 可保存对象包括网络模型、 优化器等, 而这些对象的当前状态数据可以通过自身的state_dict()函数获取。

1 torch.save({
2 ‘model’: model.state_dict(),
3 'optimizer': optimizer.state_dict(),
4 'model_path.pth')