Pytorch剪枝api测试和结果

news/2024/11/19 0:41:38/

Pytorch 官方给出的prune接口

下面是基于prune的接口进行剪枝的方法步骤

1、首先prune接口在 torch.nn.utils.prune中,目前支持的剪枝方法有:

  • RandomUnstructured
  • L1Unstructured
  • RandomStructured
  • LnStructured
  • CustomFromMask
    ps:非结构性剪枝不会给剪枝后模型的速度带来提升。

2、选择一个方法,定义好一个model后,将要剪枝的模块,及模块剪枝的部分作为函数的参数传入剪枝参数

from torch.nn.utils import prune 
prune.ln_structured(module, name="weight", amount=0.5, n=2, dim=0)
'''
module: 模型的模块名字,如 model.conv1、model.fc1 ,这些跟你在构建模型时有关,可以用 models.state_dict().keys() 查看
name:模块中要剪枝的部分,可以是、weight、bias
amount:指的是模型本次剪枝的概率
n:前面使用的是ln_structured 模型,n表示使用那种剪枝策略,L1、L2、L3
dim:表示对第几个维度进行剪枝,如卷积层可以是维度 0123
'''

3、剪枝完后会产生一个weight_mask的掩码,本身不会直接作用于模型,会产生一个weight的属性,这时候原module是不存在weight的parameter,仅仅是一个attribute
如果此时输出模型的model.state_dict().keys()
之前是 conv1.weight 变成了 conv1.weight_orig ,以及conv1.weight_mask
4、此时模型的参数仍然是没有发生变化的,需要对剪枝后的模型进行保存

prune.remove(module, 'weight')
print(list(module.named_parameters()))

5、此时模型保存的是剪枝之后的权重值,同时weight_orig已经被删除掉了
6、所以直接对每一层需要剪枝的地方选择一个剪枝方法后,直接进行剪枝就可以了,然后保存模型此时的状态参数。

对模型进行全局剪枝,prune只提供了一个全局剪枝的接口global_unstructured()

import torch.nn.utils.prune as pt_prune
pt_prune.global_unstructured(parameters_to_prune,pruning_method=pt_prune.L1Unstructured,amount=amount)
'''
parameters_to_prune:list 待剪枝模块的 名字
pruning_method:全局剪枝的方法
amount:剪枝率
'''

然后对剪枝后的模块进行remove操作即可
但是全局剪枝,只支持非结构性剪枝

prune全局非结构性剪枝测试结果

# 推理模型tiny-yolov4def model_global_prune(amount: float):detect_model = Darknet('/Users/wuzhensheng01/Documents/wzs/code/yolov4-tiny-model_pruning/cfg/yolov4-tiny.cfg')  # TODO:改成相对路径detect_model.load_weights("/Users/wuzhensheng01/Documents/wzs/code/yolov4-tiny-model_pruning/weight/yolov4-tiny.weights")parameters_to_prune = list()nums = 0for i, modules in enumerate(detect_model.models):if isinstance(modules, nn.Sequential):  for j, module in enumerate(modules):if isinstance(detect_model.models[i][j], nn.Conv2d):nums += 1parameters_to_prune.append((detect_model.models[i][j], 'weight'))elif isinstance(detect_model.models[i][j], nn.BatchNorm2d):nums += 2parameters_to_prune.append((detect_model.models[i][j], 'weight'))parameters_to_prune.append((detect_model.models[i][j], 'bias'))parameters_to_prune = tuple(parameters_to_prune)assert (nums == len(parameters_to_prune))pt_prune.global_unstructured(parameters_to_prune,pruning_method=pt_prune.L1Unstructured,amount=amount)for i, modules in enumerate(detect_model.models):if isinstance(modules, nn.Sequential):  for j, module in enumerate(modules):if isinstance(detect_model.models[i][j], nn.Conv2d):pt_prune.remove(detect_model.models[i][j], 'weight')elif isinstance(detect_model.models[i][j], nn.BatchNorm2d):pt_prune.remove(detect_model.models[i][j], 'weight')pt_prune.remove(detect_model.models[i][j], 'bias')return detect_model

base_line:
base model average time : 0.2082s
bicycle:0.605963
truck:0.814734
dog:0.870323

case 1: 剪枝率0.5 只剪卷积层.
model_pruned average time:0.2122
bicycle:0.597527
truck:0.825150
dog:0.592364

case2: 全局非结构性剪枝 剪枝率0.2
model_pruned average time:0.2078
bicycle:0.637542
truck:0.839107
dog:0.851859

case4:全局非结构性剪枝 剪枝率0.5 只剪bn层
‘’‘精度降为0’‘’

case4:全局非结构性剪枝 剪枝率0.2 只剪bn层
model_pruned average time : 0.2666
truck: 0.714715
truck: 0.594537
cat: 0.435578

case5:全局非结构性剪枝 剪枝率
model_pruned average time:0.2138
bicycle:0.636104
truck:0.840595
dog:0.850322

prune结构性剪枝测试结果

通过L2方法对模型的卷积层进行结构化剪枝(剪枝率0.5、0.4、0.2、0.1),剪枝完后模型的速度并没有变快,相反,模型的精度大幅度的下降,(模型精度下降的问题不知道是不是需要进行重新训练来提升,但是模型的速度并未得到提升)

结论:对于训练好的模型,prune接口只是提供了一种方法去“剪掉”模型每一层中最不重要的结构。而并没有稀疏训练这一步,导致在结构性剪枝中,模型的精度大幅度下降map趋近于0。同时剪枝方法只是使用简单的L1或L2对权重参数进行计算。
此外,接口中的“剪枝”只是找到模型中那些位置不重要参数,生成相应大小的掩膜,把不重要的位置置0,但是并没有删除与这些位置相连的前后层(只针对结构性剪枝而言),最后模型的权重大小并未发生改变,只是不重要的位置的参数大小变为了0,使得模型的速度并未提升。即使模型剪枝率达到95%,模型的速度仍与baseline保持一致。

结论:pytorch的官方接口并不能直接使用


http://www.ppmy.cn/news/51053.html

相关文章

geometric distribution and exponential distribution(几何分布和指数分布)

几何分布 分布函数均值和方差意义 表示经过k次实验才第一次得到正确的实验结果 比如抛硬币得到正面的需要抛的次数 指数分布 分布函数均值和方差意义 表示经过一段x之后,某件事第一次发生 比如经过x时间之后,公交车来的概率 比如餐厅从开业到第一个客人…

个人知识库(持续更新中)

打造一个属于自己的知识库 为什么会有这个知识库会记录什么内容基础知识Java核心Java WebMySQL 中间件&工具项目代码资源仿牛客社区Web开发华夏ERP软件 视频资源代码之外持续更新中… 为什么会有这个知识库 作为羊哥的死忠粉,当他谈到个人知识库这个东西的时候…

全栈成长-python学习笔记之数据类型

python数据类型 数字类型 类型类型转换整型 intint() 字符串类型转换 浮点型保留整数 int(3.14)3 int(3.94)3浮点型 floatfloat() #####字符串类型 类型类型转换字符串 strstr() 将其他数据类型转为字符串 布尔类型与空类型 布尔类型 类型类型转换布尔型 boolbool()将其他…

Git 储藏(stash)详解

文章目录 单次储藏及应用多次储藏及应用对特定范围文件进行储藏 此文在阅读前需要有一定的git命令基础,若基础尚未掌握,建议先阅读这篇文章Git命令播报详版 在平常的工作中,当我们在单独拉取的功能分支中进行开发时,遇到线上出现b…

一篇文章看懂MySQL的多表连接(包含左/右/全外连接)

MySQL的多表查询 这是第二次学习多表查询,关于左右连接还是不是很熟悉,因此重新看一下。小目标:一篇文章看懂多表查询!! 这篇博客是跟着宋红康老师学习的,点击此处查看视频,关于数据库我放在了…

spring eurake中使用IP注册

在开发spring cloud的时候遇到一个很奇葩的问题,就是服务向spring eureka中注册实例的时候使用的是机器名,然后出现localhost、xxx.xx等这样的内容,如下图: eureka.instance.perferIpAddresstrue 我不知道这朋友用的什么spring c…

2023年Spark大数据处理讲课笔记

文章目录 一、Scala语言基础二、Spark基础三、Spark RDD弹性分布式数据集 一、Scala语言基础 Spark大数据处理讲课笔记1.1 搭建Scala开发环境Spark大数据处理讲课笔记1.2 Scala变量与数据类型Spark大数据处理讲课笔记1.3 使用Scala集成开发环境Spark大数据处理讲课笔记1.4 掌握…

Java线程两种创建方式

Java中多线程的创建方式 第一种:继承Thread类的方式 1.声明Thread的子类继承Thread,在子类中重写run()方法,这东西我们也叫它线程任务。2.创建继承了Thread类的子类的对象,调用方法启动多线程。 演示: 第一步&#…