【PyTorch单点知识】神经元网络模型剪枝prune模块介绍(下,结构化剪枝)

news/2024/12/23 3:04:59/

文章目录

      • 0. 前言
      • 1. 结构化剪枝 vs 非结构化剪枝
      • 2. `torch.nn.utils.prune`中的结构化剪枝方法
      • 3. PyTorch实例
        • 3.1 `random_structured`
        • 3.2 `prune.ln_structured`
      • 4. 总结

0. 前言

按照国际惯例,首先声明:本文只是我自己学习的理解,虽然参考了他人的宝贵见解及成果,但是内容可能存在不准确的地方。如果发现文中错误,希望批评指正,共同进步。

在前文:【PyTorch单点知识】神经元网络模型剪枝prune模块介绍(上,非结构化剪枝)中介绍了PyTorch中的prune模型剪枝模块中的非结构化剪枝。本文将通过实例说明utils.prune中的结构化剪枝方法。

1. 结构化剪枝 vs 非结构化剪枝

1.1 非结构化剪枝的特征

非结构化剪枝是指在神经网络的权重矩阵中随机地移除一些权重,而不考虑这些权重在矩阵中的位置或它们是否构成某种结构(如一个完整的通道或过滤器)。这种剪枝方式通常会导致权重矩阵变得非常稀疏,但同时也破坏了权重的原有结构,这可能会对模型的并行计算效率产生负面影响,因为现代硬件(如GPU)在处理密集矩阵时更有效率。

1.2 结构化剪枝

相比之下,结构化剪枝则是在保持(部分)网络层的结构完整性的同时进行剪枝,即它会移除整个的神经元、通道、过滤器或其他结构单元,而不是某个完整结构中的单个权重。例如,在卷积层中,结构化剪枝可能涉及移除整个过滤器或输入/输出通道,而在全连接层中,则可能移除整行或整列的权重。

1.3 结构化剪枝的好处:
  1. **硬件友好性:**由于结构化剪枝保持了权重矩阵的结构,因此它更易于在现代硬件上实现高效的并行计算,不会像非结构化剪枝那样引入大量零元素,导致计算效率下降。
  2. **加速推理:**结构化剪枝通过移除整个的结构单元,可以直接减少模型的计算量和内存占用,从而显著加速推理过程。
  3. **易于部署:**结构化剪枝后的模型仍然保持原有的结构,这使得模型更容易被优化过的推理引擎(如TensorRT)所支持,便于在边缘设备或移动设备上部署。
  4. **更好的可解释性:**移除某些结构单元有时可以帮助理解哪些特征或信息对于模型的决策是不重要的,从而提高了模型的可解释性。

2. torch.nn.utils.prune中的结构化剪枝方法

本文将介绍2种结构化剪枝方法:

  • prune.random_structured: 随机结构化剪枝,按照给定维度移除随机通道。
  • prune.ln_structured: Ln范数结构化剪枝,沿着给定维度移除具有最低n范数的通道。

3. PyTorch实例

首先建立一个简单的模型:

python">import torch
import torch.nn as nn
from torch.nn.utils import prunetorch.manual_seed(888)
class SimpleModel(nn.Module):def __init__(self):super(SimpleModel, self).__init__()# 创建一个简单的卷积层self.conv = nn.Conv2d(in_channels=1, out_channels=3, kernel_size=3)model = SimpleModel()

通过print(model.conv.weight)可以打印出权重为:

python">Parameter containing:
tensor([[[[-0.3017,  0.1290, -0.2468],[ 0.2107,  0.1799,  0.1923],[ 0.1887, -0.0527,  0.1403]]],[[[ 0.0799,  0.1399, -0.0084],[ 0.2013, -0.0352, -0.1027],[-0.1724, -0.3094, -0.2382]]],[[[ 0.0419,  0.2224, -0.1558],[ 0.2084,  0.0543,  0.0647],[ 0.1493,  0.2011,  0.0310]]]], requires_grad=True)
3.1 random_structured

这个方法会在指定维度dim(默认为-1)上剪枝一个随机通道:

python">prune.random_structured(model.conv, name="weight", amount=0.33)
print(model.conv.weight)

输出为:

python">tensor([[[[-0.3017,  0.1290, -0.0000],[ 0.2107,  0.1799,  0.0000],[ 0.1887, -0.0527,  0.0000]]],[[[ 0.0799,  0.1399, -0.0000],[ 0.2013, -0.0352, -0.0000],[-0.1724, -0.3094, -0.0000]]],[[[ 0.0419,  0.2224, -0.0000],[ 0.2084,  0.0543,  0.0000],[ 0.1493,  0.2011,  0.0000]]]], grad_fn=<MulBackward0>)

由于权重的维度为[3, 1, 3, 3],我们也可以试试在其他维度(dim=0dim=2)上进行剪枝

  • dim=0
python">prune.random_structured(model.conv, name="weight", amount=0.33,dim=0)
print(model.conv.weight)

输出为:

python">tensor([[[[-0.3017,  0.1290, -0.2468],[ 0.2107,  0.1799,  0.1923],[ 0.1887, -0.0527,  0.1403]]],[[[ 0.0799,  0.1399, -0.0084],[ 0.2013, -0.0352, -0.1027],[-0.1724, -0.3094, -0.2382]]],[[[ 0.0000,  0.0000, -0.0000],[ 0.0000,  0.0000,  0.0000],[ 0.0000,  0.0000,  0.0000]]]], grad_fn=<MulBackward0>)
  • dim=2
python">prune.random_structured(model.conv, name="weight", amount=0.33,dim=2)
print(model.conv.weight)

输出为:

python">tensor([[[[-0.3017,  0.1290, -0.2468],[ 0.2107,  0.1799,  0.1923],[ 0.0000, -0.0000,  0.0000]]],[[[ 0.0799,  0.1399, -0.0084],[ 0.2013, -0.0352, -0.1027],[-0.0000, -0.0000, -0.0000]]],[[[ 0.0419,  0.2224, -0.1558],[ 0.2084,  0.0543,  0.0647],[ 0.0000,  0.0000,  0.0000]]]], grad_fn=<MulBackward0>)
3.2 prune.ln_structured

这个方法会在指定维度dim(默认为-1)上按最小n范数剪枝一个通道:

python">prune.ln_structured(model.conv, name="weight",amount=0.33,n=1,dim=-1)
print(model.conv.weight)

输出为:

python">tensor([[[[-0.3017,  0.1290, -0.0000],[ 0.2107,  0.1799,  0.0000],[ 0.1887, -0.0527,  0.0000]]],[[[ 0.0799,  0.1399, -0.0000],[ 0.2013, -0.0352, -0.0000],[-0.1724, -0.3094, -0.0000]]],[[[ 0.0419,  0.2224, -0.0000],[ 0.2084,  0.0543,  0.0000],[ 0.1493,  0.2011,  0.0000]]]], grad_fn=<MulBackward0>)

更改dim也是同样的效果:

  • dim=0
python">prune.ln_structured(model.conv, name="weight",amount=0.33,n=1,dim=0)
print(model.conv.weight)

输出为:

python">tensor([[[[-0.3017,  0.1290, -0.2468],[ 0.2107,  0.1799,  0.1923],[ 0.1887, -0.0527,  0.1403]]],[[[ 0.0799,  0.1399, -0.0084],[ 0.2013, -0.0352, -0.1027],[-0.1724, -0.3094, -0.2382]]],[[[ 0.0000,  0.0000, -0.0000],[ 0.0000,  0.0000,  0.0000],[ 0.0000,  0.0000,  0.0000]]]], grad_fn=<MulBackward0>)
  • dim=2
python">prune.ln_structured(model.conv, name="weight",amount=0.33,n=1,dim=2)
print(model.conv.weight)

输出为:

python">tensor([[[[-0.3017,  0.1290, -0.2468],[ 0.0000,  0.0000,  0.0000],[ 0.1887, -0.0527,  0.1403]]],[[[ 0.0799,  0.1399, -0.0084],[ 0.0000, -0.0000, -0.0000],[-0.1724, -0.3094, -0.2382]]],[[[ 0.0419,  0.2224, -0.1558],[ 0.0000,  0.0000,  0.0000],[ 0.1493,  0.2011,  0.0310]]]], grad_fn=<MulBackward0>)

4. 总结

至此,prune中的非结构化剪枝和结构化剪枝介绍完毕!


http://www.ppmy.cn/news/1472575.html

相关文章

Flutter【组件】点击类型表单项

简介 flutter 点击表单项组件&#xff0c;适合用户输入表单的场景。 点击表单项组件是一个用户界面元素&#xff0c;通常用于表单或设置界面中&#xff0c;以便用户可以点击它们来选择或更改某些设置或输入内容。这类组件通常由一个标签和一个可点击区域组成&#xff0c;并且…

【AI原理解析】—支持向量机原理

目录 1. 支持向量机&#xff08;SVM&#xff09;概述 2. 超平面与支持向量 3. 间隔最大化 4. 优化问题 5. 核函数 6. 总结 1. 支持向量机&#xff08;SVM&#xff09;概述 定义&#xff1a;支持向量机是一种监督学习模型&#xff0c;主要用于数据分类问题。其基本思想是…

EE trade:利弗莫尔三步建仓法

在股市投资领域&#xff0c;利弗莫尔这个名字代表着无数的智慧和经历。他的三步建仓法成为了投资者们趋之若鹜的学习对象。本文将详细解析利弗莫尔的著名买入法&#xff0c;通过分步进攻方式&#xff0c;有效掌控市场并实现盈利。 一、利弗莫尔的三步建仓法详解 利弗莫尔三步…

【介绍下SCSS的基本使用】

&#x1f3a5;博主&#xff1a;程序员不想YY啊 &#x1f4ab;CSDN优质创作者&#xff0c;CSDN实力新星&#xff0c;CSDN博客专家 &#x1f917;点赞&#x1f388;收藏⭐再看&#x1f4ab;养成习惯 ✨希望本文对您有所裨益&#xff0c;如有不足之处&#xff0c;欢迎在评论区提出…

SpringMVC基础详解

文章目录 一、SpringMVC简介1、什么是MVC2、MVC架构模式与三层模型的区别3、什么是SpringMVC 二、HelloWorld程序1、pom文件2、springmvc.xml3、配置web.xml文件4、html文件5、执行Controller 三、RequestMapping注解1、value属性1.1、基础使用1.2、Ant风格&#xff08;模糊匹配…

探索QCS6490目标检测AI应用开发(三):模型推理

作为《探索QCS6490目标检测AI应用开发》文章&#xff0c;紧接上一期&#xff0c;我们介绍如何在应用程序中介绍如何使用解码后的视频帧结合Yolov8n模型推理。 高通 Qualcomm AI Engine Direct 是一套能够针对高通AI应用加速的软件SDK&#xff0c;更多的内容可以访问&#xff1a…

STM32之三:中断外部中断

目录 1. 什么是中断 1.1 中断概念 1.2 中断优先级 1.3 中断嵌套 2.STM32中断 2.1 NVIC中断优先级 3 外部中断 3.1 EXTI简介 3.2 EXTI中断/事件线 3.3 EXTI功能框图 3.4 中断和事件的区别&#xff1f; 3.5 什么时候用外部中断&#xff1f; 3.怎么使用STM32中断 3.…

ISO 19110操作要求类中的/req/operation/formal-definition详细解释

/req/operation/formal-definition 要求: 每个要素操作实体必须具有一个形式定义&#xff08;formal definition&#xff09;&#xff0c;该定义应明确描述操作的行为和影响。 具体解释 定义 要素操作实体&#xff08;feature operation entity&#xff09;&#xff1a;这…