BoTNet:Bottleneck Transformers for Visual Recognition


【GiantPandaCV导语】基于Transformer的骨干网络,同时使用卷积与自注意力机制来保持全局性和局部性。模型在ResNet最后三个BottleNeck中使用了MHSA替换3x3卷积。属于早期的结合CNN+Transformer的工作。简单来讲Non-Local+Self Attention+BottleNeck = BoTNet

引言

本文的发展脉络如下图所示:

实际上沿着Transformer Block改进的方向进行的,与CNN架构也是兼容的。具体结构如下图所示:

两者都遵循了BottleNeck的设计原则,可以有效降低计算量。

设计Transformer中self attention存在几个挑战:

  • 图片尺寸比较大,比如目标检测中分辨率在1024x1024
  • 内存和计算量的占用高,导致训练开销比较大。

本文设计如下:

  • 使用卷积识别底层特征的抽象信息。
  • 使用self attention处理通过卷积层得到的高层信息。

这样可以有效处理大分辨率图像。

方法

BoTNet中MHSA模块如下图所示:

上边的这个MHSA Block是核心创新点,其与Transformer中的MHSA有所不同:

  • 由于处理对象不是一维的,而是类似CNN模型,所以有非常多特性与此相关。
  • 归一化这里并没有使用Layer Norm而是采用的Batch Norm,与CNN一致。
  • 非线性激活,BoTNet使用了三个非线性激活
  • 左侧content-position模块引入了二维的位置编码,这是与Transformer中最大区别。

由于该模块是处理BxCHW的形式,所以难免让人想起来Non Local 操作,这里列出笔者以前绘制的一幅图:

可以看出主要区别就是在于Content-postion模块引入的位置信息。

BoTNet细节设计:

整体的设计和ResNet50几乎一样,唯一不同在于最后一个阶段中三个BottleNeck使用了MHSA模块。具体这样做的原因是Self attention需要消耗巨大的计算量,在模型最后加入时候feature map的size比较小,相对而言计算量比较小。

实验

在目标检测和分割领域性能对比

分辨率改变对BoTNet帮助更大

消融实验-相对位置编码

BoTNet对ResNet系列模型的提升

BoTNet与更大的图片适配

BoTNet与Non-Local Net的比较

与ImageNet上结果比较

模型放缩的影响

显卡香气飘来,又是谷歌的骚操作,将EfficientNet方法放在BoTNet上:

可以看出与期望相符合,Transformer架构带来的性能上限要高于CNN,虽然模型大小比较小的时候性能比较弱,但是模型量变大以后其性能就有了保证。

代码

核心模块:MHSA (由第三方进行实现)

class MHSA(nn.Module):
    def __init__(self, n_dims, width=14, height=14, heads=4):
        super(MHSA, self).__init__()
        self.heads = heads

        self.query = nn.Conv2d(n_dims, n_dims, kernel_size=1)
        self.key = nn.Conv2d(n_dims, n_dims, kernel_size=1)
        self.value = nn.Conv2d(n_dims, n_dims, kernel_size=1)

        self.rel_h = nn.Parameter(torch.randn([1, heads, n_dims // heads, 1, height]), requires_grad=True)
        self.rel_w = nn.Parameter(torch.randn([1, heads, n_dims // heads, width, 1]), requires_grad=True)

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        n_batch, C, width, height = x.size()
        q = self.query(x).view(n_batch, self.heads, C // self.heads, -1)
        k = self.key(x).view(n_batch, self.heads, C // self.heads, -1)
        v = self.value(x).view(n_batch, self.heads, C // self.heads, -1)

        content_content = torch.matmul(q.permute(0, 1, 3, 2), k)

        content_position = (self.rel_h + self.rel_w).view(1, self.heads, C // self.heads, -1).permute(0, 1, 3, 2)
        content_position = torch.matmul(content_position, q)

        energy = content_content + content_position
        attention = self.softmax(energy)

        out = torch.matmul(v, attention.permute(0, 1, 3, 2))
        out = out.view(n_batch, C, width, height)

        return out

参考

https://arxiv.org/abs/2101.11605

https://zhuanlan.zhihu.com/p/347849929

https://github.com/leaderj1001/BottleneckTransformers/blob/main/model.py

跑不动ImageNet,想试试Vision Transformer的同学可以看看这个仓库,

https://github.com/pprp/pytorch-cifar-model-zoo

在CIFAR10上测试:

python train.py --model 'botnet' --name "fast_training" --sched 'cosine' --epochs 100 --cutout True --lr 0.1 --bs 128 --nw 4

目前可以在100个epoch内达到验证集91.1%的准确率。