Data-Free Quantization代码实现


(本文首发于公众号,没事来逛逛)

前面两篇文章介绍 Data-Free Quantization,这篇文章准备用 pytorch 实现一遍 weight equalize 算法,并捋一下刚踩的新坑。

Weight Equalization

其实 pytorch 官方有一个 weight equalize 的实现 (参考:https://github.com/pytorch/pytorch/blob/v1.8.1/torch/quantization/_equalize.py),但这个实现没法用于 depthwise conv,因此这篇文章准备自己实现一个 weight equalize。

新踩的坑

然而,就当我信心满满准备用 mobilenet 跑一波效果时,突发又踩到另一个坑。

记得我在之前的https://github.com/pytorch/pytorch/blob/v1.8.1/torch/fx/_experimental/fuser.py),由于涉及到 FX 的用法,之后有机会再细讲。

整个算法框架如下:

def equalize(model, inplace=False):

    if not inplace:
        model = deepcopy(model)

    model.eval()
    # 提取模型graph,方便匹配卷积对
    model = torch.fx.symbolic_trace(model)
    
    # 步骤1:fuse BN
    model = fuse(model)
    
    # 步骤2:寻找适合equalize的卷积对
    paired_modules_list = _find_module_pairs(model)

    name_to_module = {}
    name_set = {name for pair in paired_modules_list for name in pair}

    for name, module in model.named_modules():
        if name in name_set:
            name_to_module[name] = module
            
    # 步骤3:每个卷积对进行equalize
    for i, pair in enumerate(paired_modules_list):

        print("equalize: ", pair)

        if len(pair) == 2:
            _cross_layer_equalization(name_to_module[pair[0]], name_to_module[pair[1]])
        
        elif len(pair) == 3:
            _cross_layer_depthwise_equalization(name_to_module[pair[0]], name_to_module[pair[1]], name_to_module[pair[2]])

    return model

接下来是踩坑重点:寻找合适的卷积对,这部分是需要分 conv 和 fc 两种情况实现的:

def _find_module_pairs(model):
    name_modules = dict(model.named_modules())
    module_pair_lists = []
    
    for node in model.graph.nodes:
        if node.op == "call_module":  # "call_module"表示torch.nn中定义的op
            module = name_modules[node.target]
            
            if type(module) == torch.nn.Conv2d and \
                 module.groups == 1: # 第一个卷积默认是普通卷积,group=1
                layer_group = _find_conv_downstream_layer_to_scale(node, name_modules)
                if len(layer_group) != 0:
                    module_pair_lists.append(layer_group)
            if type(module) == torch.nn.Linear:
                layer_group = _find_fc_downstream_layer_to_scale(node, name_modules)
                if len(layer_group) != 0:
                    module_pair_lists.append(layer_group)
    
    module_pair_lists = [pair for pair in module_pair_lists if len(pair) == 2 or len(pair) == 3]
    return module_pair_lists
  

def _find_conv_downstream_layer_to_scale(cur_node, name_modules):

    layer_group = []
    # 匹配conv->(relu)->conv,匹配到的话,就把卷积对放到layer_group
    if _match_conv_conv(cur_node, name_modules, layer_group):
        return layer_group
    # 匹配conv->(relu)->dw conv->(relu)->conv,匹配到的话,就把卷积对放到layer_group
    elif _match_conv_dwconv_conv(cur_node, name_modules, layer_group):
        return layer_group
    else:
        return []

这部分代码比较杂,我加了一些注释,希望能帮助需要的同学看懂。当然在此之前最好先熟悉一下 FX 中的一些 api 的使用。

接下来是另外两个核心函数:一个是对 conv-conv 这样的匹配对进行 equalize,并一个则是对 conv-dwconv-conv 进行 equalize,这一部分的实现和我前一篇文章中给出的伪代码基本一样,这里简单贴出处理 conv-conv 的代码:

def _cross_layer_equalization(module1, module2):
    if type(module1) not in _supported_types or type(module2) not in _supported_types:
        raise ValueError("module type not supported:", type(module1), " ", type(module2))
    
    weight1 = module1.weight
    weight2 = module2.weight
    bias1 = module1.bias
    
    # 重排,这一部分其实可有可无,不过重排可以更好地理解代码逻辑
    if type(module2) == torch.nn.Conv2d:
        weight2 = weight2.permute(1, 0, 2, 3)  
    elif type(module2) == torch.nn.Linear:
        weight2 = weight2.permute(1, 0)
        
    # 计算两个weight的数值范围
    r1 = compute_range(weight1)  
    r2 = compute_range(weight2)
    
    # 计算缩放因子,这里包含了每个kernel的缩放系数
    s = r1 / torch.sqrt(r1 * r2)

    # 对scale进行维度扩张,方便进行broadcast
    size = [1] * weight1.ndim
    size[0] = weight1.size(0)
    s = torch.reshape(s, size)

    weight1 = weight1 * (1 / s)
    weight2 = weight2 * s

    if type(module2) == torch.nn.Conv2d:
        weight2 = weight2.permute(1, 0, 2, 3)
    elif type(module2) == torch.nn.Linear:
        weight2 = weight2.permute(1, 0)

    module1.weight = torch.nn.Parameter(weight1)
    module2.weight = torch.nn.Parameter(weight2)

    if bias1 is not None:
        s = s.view(-1,)
        bias1 = bias1 * (1 / s)
        module1.bias = torch.nn.Parameter(bias1)

这一部分相对好理解一些,同样地,我也加了一些注释,方便有需要的同学理解。

效果如何

这里我分别测试了 mobilenetv2 (把 relu6 换成 relu) 和 resnet18 的效果 (完整测试代码见 test_weight_equalize.py 文件)。

mobilenetv2

首先,我查看了 mobilenetv2 前几层可分离卷积的数值范围:

这个数值范围确实比较大,但似乎还能忍受。

然后是把 BatchNorm 和 Conv 合并后:

但这一步,这个数值范围就大的有点难以接受了,和论文里面给出的比较相似了。

做完 weight equalize 后:

数值范围拉小了很多,跟第一张图比较接近了。

resnet18

然后再看一下 resnet18 的情况。

同样地,看一下前几层卷积的数值范围:

合并 BN 后:

做完 weight equalize 后:

几乎没啥变化,所以 weight equalize 在这种没有可分离卷积的网络上面其实作用不大。

由此,可以初步得出一个结论:

  1. depthwise conv 会使得卷积的 kernel 之间在数值分布上产生较大差异;
  2. batchnorm 会使得这种差异进一步放大,因为 batchnorm 会单独对每个 input channel 都会计算均值和方差。

最后,再给大家提个醒,如果想要使用 weight equalize,先看看你的网络里面是不是使用了很多 depthwise conv,以及这些 conv 之间的激活函数是不是 ReLU、LeakyReLU、PReLU 这些,以及有没有 group conv 在里面破坏氛围,当这几点都满足后,weight equalize 才能发挥作用。

总结

这篇文章介绍了我用 pytorch fx 手撸 Weight Equalization 算法的过程中踩的几个坑,可以看出,这个算法对 mobilenetv2 这类网络确实有不小的作用,但限制也挺多,比如对激活函数有比较大的约束等。

代码方面应该还存在不少 bug。此外,我在看公司大佬实现的代码时发现,其实没必要像高通那样把 depthwise conv 单独拿出来实现,可以把可分离卷积和分组卷积都统一起来,用 conv-conv 的模式做 equalize 就可以,这样可以简便很多,效果也相差无几。这里就不方便透漏太多了。

另外,有读者问:还有 Bias Correction 呢?哪去了!

被 Weight Equalize 坑了这么久,只剩一口仙气了,谁爱 Correction 谁去

欢迎关注我的公众号:大白话AI,立志用大白话讲懂AI。