第四次作业:猫狗大战挑战赛
使用VGG模型进行猫狗大战
一、
首先,下载数据,然后进行数据处理,在使用CNN处理图像时,需要进行预处理。图片将被整理成 的大小,同时还将进行归一化处理。
torchvision 支持对输入数据进行一些复杂的预处理/变换
这里创建vgg 模型对输入的5个图片利用VGG模型进行预测,同时,使用softmax对结果进行处理,随后展示了识别结果。可以看到,识别结果是比较非常准确的。
接下来修改最后一层,冻结前面层的参数,将required_grad设置为False,这样前面层的权重就不会自动更新。
然后训练并测试全连接层,创建损失函数和优化器、训练模型,共训练一千八百个样本,其中九百张猫九百张狗。
然后测试模型,共两千个测试样本,由于GPU不能用了,导致测试了整整三十九分钟,还是十分痛苦的...
最后对接过进行主观分析
二、猫狗大战实战:
先上传解压数据集,先通过train,valid来训练自己的模型,并测试准确率
然后由文件创建数据集,装载数据集。
创建vgg模型:
修改最后一层,nn.linear层由1000类替换为2类,然后将required_gard 设置为false,冻结前面层的参数
这样反向传播训练梯度时前面层的的权重就不会自动更新,只更新最后一层。
然后训练并测试全连接层:创建顺势函数和优化器,然后训练并测试模型
结果截图