Pytorch Network Slimming
论文: Learning Efficient Convolutional Networks Through Network Slimming
作者 pytorch 实现: Eric-mingjie/network-slimming
本项目实现(以下简称PNS): Sanster/pytorch-network-slimming
本文主要从工程角度介绍 PNS 的优势和做法,关于论文本身的理论部分这里不做介绍,可以直接看下论文。
官方项目分析
原作者的项目中实现了三种网络结构: vgg , densenet , preresnet ,在网络的初始化部分,都会传入 cfg 来构建网络,对于不同网络,cfg 中的参数含义各有不同,包含网络的层数、通道数、block 的数量等。对应三种网络结构,有三份用来剪枝的脚本 vggprune.py , denseprune.py 和 resprune.py ,三份脚本的差别不大,有很多代码可以复用,主要的流程就是按照论文里描述的,先根据剪枝比例和 BN 层的 weights 计算全局 threshold,确定每层 BN 要保留的 channel 索引,再走剪枝的逻辑,获得剪枝后模型结构的 cfg 参数,并保存在 checkpoint 中,用于下次重新恢复剪枝结构。三个剪枝脚本中的主要区别在于最后的剪枝部分,以 resnet 和 vgg 为例,resnet 有 shortcuts,vgg 没有,而剪枝时涉及到 shortcuts 的 Tensor,Channel 维必须一致才能相加,所以 resnet 中会有特殊的剪枝逻辑,这一段剪枝的 代码 第一次看的时候真的是绕晕了。。。
可以看到如果想按照作者的方法把 Network Slimming 集成到已有项目中,需要做以下事情:
- 如果不想要作者提供的网络结构,那就要修改已有项目中构建模型部分的代码,提取出一个存放 build model 用的 cfg 参数。这一步可能就会比较麻烦,以 detectron2 中 build_resnet_backbone 的代码为例,最终的 resnet 实例是根据各种配置生成的,构建逻辑还是比较复杂的。如果要应用于检测模型,除了 backbone 以外,还要考虑 FPN、检测头等组件的修改
- 完成了上述 cfg 修改后,剪枝的逻辑也是要单独再写的(不同的 backbone,FPN 和 检测头),例如 github 上对 yolov3 进行剪枝的项目: YOLOv3-complete-pruning 、 yolov3-channel-and-layer-pruning
集成 PNS 的流程
- 对于不同的网络结构,跑一遍 gen_schema.py (almost)自动生成剪枝的 Scheme,以 json 形式保存下来。稀疏化训练完成后使用该 schema 并指定剪枝率即可获得剪枝后的模型结构。
- 提供剪枝结果的保存、恢复接口,剪枝结果包括剪枝后的网络结构和 fine tune 后的网络权重。
上述流程中 无需修改已有的构建网络部分的代码 ,可以实现非侵入式的集成。以一个简单的模型为例:
import torch
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv1 = torch.nn.Conv2d(3, 8, 1)
self.bn1 = torch.nn.BatchNorm2d(8)
self.conv2 = torch.nn.Conv2d(8, 8, 1)
self.bn2 = torch.nn.BatchNorm2d(8)
def forward(self, x):
x = self.conv1(x)
x_bn1_out = self.bn1(x)
x = torch.nn.functional.relu(x_bn1_out)
x = self.conv2(x)
x = self.bn2(x)
x = x + x_bn1_out
return x
执行完
gen_schema.py
以后会生成如下的 schema,涉及到 shortcuts 的层也已经自动追踪到了,
"method": "or"
表示 shortcuts 层保留下来的 channel 索引取并集,也支持
and
取交集,获得更大的剪枝率。
{
"modules": [
"name": "conv1",
"prev_bn": "",
"next_bn": "bn1"
"name": "conv2",
"prev_bn": "bn1",
"next_bn": "bn2"
"shortcuts": [
"names": [
"bn1",
"bn2"
"method": "or"
}
完成稀疏化训练后对网络进行剪枝,这里
pruning_result
表示每一层 BN、卷积保留的通道索引,pruned_model 为需要进行 fine tune 的 nn.Module 实例,fine tune 之后把这两个剪枝相关的参数保存下来,用于下次恢复剪枝结构和参数。
pruner = SlimPruner(trained_model, prune_schema_file_path)
pruning_result: List[Dict] = pruner.run(prune_ratio=0.75)
model = pruner.pruned_model
# fine tune model...
# save pruning_result and model
data = {
"pruned_model": model.state_dict(),
"pruning_result": pruning_result
with open("model.pth", "wb") as f:
torch.save(data, f)
恢复剪枝结构
ckpt = torch.load("model.pth")