添加链接
link之家
链接快照平台
  • 输入网页链接,自动生成快照
  • 标签化管理网页链接
相关文章推荐
沉着的热水瓶  ·  Assets, Resources and ...·  1 年前    · 
重情义的数据线  ·  RuoYi-Vue 2.3 ...·  1 年前    · 
帅气的甘蔗  ·  qtableview文字居中-掘金·  1 年前    · 
Pytorch Network Slimming

Pytorch Network Slimming

论文: Learning Efficient Convolutional Networks Through Network Slimming

作者 pytorch 实现: Eric-mingjie/network-slimming

本项目实现(以下简称PNS): Sanster/pytorch-network-slimming

本文主要从工程角度介绍 PNS 的优势和做法,关于论文本身的理论部分这里不做介绍,可以直接看下论文。

官方项目分析

原作者的项目中实现了三种网络结构: vgg , densenet , preresnet ,在网络的初始化部分,都会传入 cfg 来构建网络,对于不同网络,cfg 中的参数含义各有不同,包含网络的层数、通道数、block 的数量等。对应三种网络结构,有三份用来剪枝的脚本 vggprune.py denseprune.py resprune.py ,三份脚本的差别不大,有很多代码可以复用,主要的流程就是按照论文里描述的,先根据剪枝比例和 BN 层的 weights 计算全局 threshold,确定每层 BN 要保留的 channel 索引,再走剪枝的逻辑,获得剪枝后模型结构的 cfg 参数,并保存在 checkpoint 中,用于下次重新恢复剪枝结构。三个剪枝脚本中的主要区别在于最后的剪枝部分,以 resnet 和 vgg 为例,resnet 有 shortcuts,vgg 没有,而剪枝时涉及到 shortcuts 的 Tensor,Channel 维必须一致才能相加,所以 resnet 中会有特殊的剪枝逻辑,这一段剪枝的 代码 第一次看的时候真的是绕晕了。。。

可以看到如果想按照作者的方法把 Network Slimming 集成到已有项目中,需要做以下事情:

  1. 如果不想要作者提供的网络结构,那就要修改已有项目中构建模型部分的代码,提取出一个存放 build model 用的 cfg 参数。这一步可能就会比较麻烦,以 detectron2 中 build_resnet_backbone 的代码为例,最终的 resnet 实例是根据各种配置生成的,构建逻辑还是比较复杂的。如果要应用于检测模型,除了 backbone 以外,还要考虑 FPN、检测头等组件的修改
  2. 完成了上述 cfg 修改后,剪枝的逻辑也是要单独再写的(不同的 backbone,FPN 和 检测头),例如 github 上对 yolov3 进行剪枝的项目: YOLOv3-complete-pruning yolov3-channel-and-layer-pruning

集成 PNS 的流程

  1. 对于不同的网络结构,跑一遍 gen_schema.py (almost)自动生成剪枝的 Scheme,以 json 形式保存下来。稀疏化训练完成后使用该 schema 并指定剪枝率即可获得剪枝后的模型结构。
  2. 提供剪枝结果的保存、恢复接口,剪枝结果包括剪枝后的网络结构和 fine tune 后的网络权重。

上述流程中 无需修改已有的构建网络部分的代码 ,可以实现非侵入式的集成。以一个简单的模型为例:

import torch
class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(3, 8, 1)
        self.bn1 = torch.nn.BatchNorm2d(8)
        self.conv2 = torch.nn.Conv2d(8, 8, 1)
        self.bn2 = torch.nn.BatchNorm2d(8)
    def forward(self, x):
        x = self.conv1(x)
        x_bn1_out = self.bn1(x)
        x = torch.nn.functional.relu(x_bn1_out)
        x = self.conv2(x)
        x = self.bn2(x)
        x = x + x_bn1_out
        return x

执行完 gen_schema.py 以后会生成如下的 schema,涉及到 shortcuts 的层也已经自动追踪到了, "method": "or" 表示 shortcuts 层保留下来的 channel 索引取并集,也支持 and 取交集,获得更大的剪枝率。

{
  "modules": [
      "name": "conv1",
      "prev_bn": "",
      "next_bn": "bn1"
      "name": "conv2",
      "prev_bn": "bn1",
      "next_bn": "bn2"
  "shortcuts": [
      "names": [
        "bn1",
        "bn2"
      "method": "or"
}

完成稀疏化训练后对网络进行剪枝,这里 pruning_result 表示每一层 BN、卷积保留的通道索引,pruned_model 为需要进行 fine tune 的 nn.Module 实例,fine tune 之后把这两个剪枝相关的参数保存下来,用于下次恢复剪枝结构和参数。

pruner = SlimPruner(trained_model, prune_schema_file_path)
pruning_result: List[Dict] = pruner.run(prune_ratio=0.75)
model = pruner.pruned_model
# fine tune model...
# save pruning_result and model
data = {
  "pruned_model": model.state_dict(),
  "pruning_result": pruning_result
with open("model.pth", "wb") as f:
    torch.save(data, f)

恢复剪枝结构

ckpt = torch.load("model.pth")