python调用pytorch实现空间注意力分类卷积网络——以猫狗分类为例


目录
  • 程序简介
  • 程序/数据集下载
  • 代码分析

程序简介

项目调用pytorch搭建了基于空间注意力机制的卷积神经网络模型对猫狗图片进行分类,并对空间注意力进行可视化,以增加神经网络的可解释性

空间注意力机制(Spatial Attention)可以看作神经网络对输入图像的局部特征做了一次加权,比如第一眼看见下面这只猫猫对它各部分的在意程度不同

程序/数据集下载

点击进入下载地址

代码分析

导入模块

import glob
import cv2
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import accuracy_score
import shutil
import os

简单的数据增强,即水平翻转和图片重定尺寸

def augment(imgPath,flip=False):
    '''
    读取图片,重定尺寸,并做旋转操作,返回图片矩阵
    path:图片路径
    flip:是否翻转
    '''
    #图片矩阵
    img = cv2.imdecode(np.fromfile(imgPath, dtype=np.uint8),-1)
    img = cv2.resize(img,(50,50))
    #旋转图片矩阵
    if flip:
        img=cv2.flip(img,1)
    return img
img1 = augment(cats[0])
img2 = augment(cats[0],True)
img = np.concatenate([img1,img2],axis=1)
#plt.matshow(img)

空间注意力模型:差不多就是对每个局部像素求了最大值和平均值,拼接起来后经过一层卷积和sigmoid激活

class SpatialAttention(nn.Module):
    def __init__(self, kernelSize=7):
        '''空间注意力'''
        super(SpatialAttention, self).__init__()
        padding = 3
        self.conv1 = nn.Conv2d(2, 1, kernelSize, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avgOut = torch.mean(x, dim=1, keepdim=True)
        maxOut, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avgOut, maxOut], dim=1)
        x = self.conv1(x)
        return self.sigmoid(x)

搭建有空间注意力的卷积网络

class ConvBlock(nn.Module):
    def __init__(self, inCh, outCh, kSize=3,stride=1,padding=1):
        '''卷积块'''
        super(ConvBlock, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(inCh, outCh, kernel_size=kSize,stride=stride, padding=padding),#卷积层
            nn.BatchNorm2d(outCh), #批量标准化层
            nn.PReLU(),#PReLU激活层
        )

    def forward(self, x):
        return self.conv(x)
    
class CNN(nn.Module):
    def __init__(self):
        '''空间注意力CNN分类模型,输入为200x200x3的图片'''
        super(CNN, self).__init__()
        self.spacialAtt = SpatialAttention()
        self.conv1 = ConvBlock(3,50)
        self.conv2 = ConvBlock(50,1)
        self.fc = nn.Linear(2500,2)

    def forward(self,x):
        x = x.permute(0,3,1,2)
        spacialW = self.spacialAtt(x)
        x = x.mul(spacialW)
        x = self.conv1(x)
        x = self.conv2(x)
        a,b,c,d = x.shape
        x = x.reshape(a,b*c*d)
        x = self.fc(x)
        return x

def data2numpy(data):
    '''张量转为CPU的numpy类型'''
    return data.detach().cpu().numpy()

# model = CNN()
# model(torch.zeros(3,50,50,3))

循环训练网络,打印出训练过程,图片较少,没有使用测试集,反正只是为了学习思路o_0

steps = 600#迭代次数
batchSize = 300#批处理量
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')#使用GPU CPU
paths = glob.glob("Static/imgs/*.jpg")#所有图片路径
labels = ["cat","dog"]#标签集
#实例化网络
model = CNN()
#多分类交叉熵损失
lossFun = nn.CrossEntropyLoss()
#优化器
opt = torch.optim.Adam(model.parameters(),lr=1e-4)
model.to(device)
#进度条
bar = tqdm(range(steps))
for step in bar:
    model.train()
    #批量取出图片 和 对应标签
    batchPaths = np.random.choice(paths,batchSize,replace=False)
    xBatch = []
    yBatch = []
    for path in batchPaths:
        for label in labels:
            if label in path:
                yBatch.append(labels.index(label))
        if np.random.randint(0,2):
            img = augment(path,True)
        else:
            img = augment(path)
        xBatch.append(img)
    xBatch = np.stack(xBatch,axis=0)
    xBatch = torch.FloatTensor(xBatch).to(device)
    yBatch = torch.LongTensor(yBatch).to(device)
    #训练
    model.zero_grad()
    yPred = model(xBatch)
    loss = lossFun(yPred, yBatch)
    loss.backward()
    opt.step()
    #计算当前训练集的准确率
    labelPred = yPred.argmax(axis=-1)
    acc = accuracy_score(data2numpy(yBatch),data2numpy(labelPred))
    bar.set_description("loss:%.3f acc:%.3f"%(loss,acc))
loss:0.339 acc:0.887: 100%|██████████████████████████████████████████████████████████| 600/600 [01:39<00:00,  6.01it/s]

可视化一下空间注意力的权重,即神经网络到底在乎图片哪一部分,在下面的组合图中右侧,像素点越白,说明神经网络越在乎那个点,虽然还是很抽象。。。但可以看出注意力层已经过滤掉了大部分无关点,留下了狗的部分┭┮﹏┭┮

model.eval()
#随便读取一张图片
x = xBatch.permute(0,3,1,2)
attImg = data2numpy(model.spacialAtt(x)[0,0])
x = data2numpy(x[0]).transpose(1,2,0)
x = x.astype(np.uint8)
#注意力图转3通道标准化
cv2.normalize(attImg, attImg, 0, 255, cv2.NORM_MINMAX)
attImg = np.stack([attImg,attImg,attImg],axis=-1).astype(np.uint8)
combine = np.concatenate([x,attImg],axis=1)
plt.matshow(combine)

本文章只发布于博客园爆米算法,被抄袭后可能排版错乱或下载失效,作者:爆米LiuChen

相关