【PyTorch单点知识】神经元网络模型剪枝prune模块介绍(下,结构化剪枝)

server/2024/11/24 21:28:57/

文章目录

      • 0. 前言
      • 1. `torch.nn.utils.prune`中的结构化剪枝方法
      • 2. PyTorch实例
        • 2.1 `random_structured`
        • 2.2 `prune.ln_structured`
      • 3. 总结

0. 前言

按照国际惯例,首先声明:本文只是我自己学习的理解,虽然参考了他人的宝贵见解及成果,但是内容可能存在不准确的地方。如果发现文中错误,希望批评指正,共同进步。

在前文:【PyTorch单点知识】神经元网络模型剪枝prune模块介绍(上,非结构化剪枝)中介绍了PyTorch中的prune模型剪枝模块中的非结构化剪枝。本文将通过实例说明utils.prune中的结构化剪枝方法。

1. torch.nn.utils.prune中的结构化剪枝方法

  • prune.random_structured: 随机结构化剪枝,按照给定维度移除随机通道。
  • prune.ln_structured: Ln范数结构化剪枝,沿着给定维度移除具有最低n范数的通道。

2. PyTorch实例

首先建立一个简单的模型:

python">import torch
import torch.nn as nn
from torch.nn.utils import prunetorch.manual_seed(888)
class SimpleModel(nn.Module):def __init__(self):super(SimpleModel, self).__init__()# 创建一个简单的卷积层self.conv = nn.Conv2d(in_channels=1, out_channels=3, kernel_size=3)model = SimpleModel()

通过print(model.conv.weight)可以打印出权重为:

python">Parameter containing:
tensor([[[[-0.3017,  0.1290, -0.2468],[ 0.2107,  0.1799,  0.1923],[ 0.1887, -0.0527,  0.1403]]],[[[ 0.0799,  0.1399, -0.0084],[ 0.2013, -0.0352, -0.1027],[-0.1724, -0.3094, -0.2382]]],[[[ 0.0419,  0.2224, -0.1558],[ 0.2084,  0.0543,  0.0647],[ 0.1493,  0.2011,  0.0310]]]], requires_grad=True)
2.1 random_structured

这个方法会在指定维度dim(默认为-1)上剪枝一个随机通道:

python">prune.random_structured(model.conv, name="weight", amount=0.33)
print("Weight after RandomStructured pruning (33%):")
print(model.conv.weight)

输出为:

python">Weight after RandomStructured pruning (33%):
tensor([[[[-0.3017,  0.1290, -0.0000],[ 0.2107,  0.1799,  0.0000],[ 0.1887, -0.0527,  0.0000]]],[[[ 0.0799,  0.1399, -0.0000],[ 0.2013, -0.0352, -0.0000],[-0.1724, -0.3094, -0.0000]]],[[[ 0.0419,  0.2224, -0.0000],[ 0.2084,  0.0543,  0.0000],[ 0.1493,  0.2011,  0.0000]]]], grad_fn=<MulBackward0>)

由于权重的维度为[3, 1, 3, 3],我们也可以试试在其他维度(dim=0dim=2)上进行剪枝

  • dim=0
python">prune.random_structured(model.conv, name="weight", amount=0.33,dim=0)
print(model.conv.weight)

输出为:

python">tensor([[[[-0.3017,  0.1290, -0.2468],[ 0.2107,  0.1799,  0.1923],[ 0.1887, -0.0527,  0.1403]]],[[[ 0.0799,  0.1399, -0.0084],[ 0.2013, -0.0352, -0.1027],[-0.1724, -0.3094, -0.2382]]],[[[ 0.0000,  0.0000, -0.0000],[ 0.0000,  0.0000,  0.0000],[ 0.0000,  0.0000,  0.0000]]]], grad_fn=<MulBackward0>)
  • dim=2
python">prune.random_structured(model.conv, name="weight", amount=0.33,dim=2)
print(model.conv.weight)

输出为:

python">tensor([[[[-0.3017,  0.1290, -0.2468],[ 0.2107,  0.1799,  0.1923],[ 0.0000, -0.0000,  0.0000]]],[[[ 0.0799,  0.1399, -0.0084],[ 0.2013, -0.0352, -0.1027],[-0.0000, -0.0000, -0.0000]]],[[[ 0.0419,  0.2224, -0.1558],[ 0.2084,  0.0543,  0.0647],[ 0.0000,  0.0000,  0.0000]]]], grad_fn=<MulBackward0>)
2.2 prune.ln_structured

这个方法会在指定维度dim(默认为-1)上按最小n范数剪枝一个通道:

python">prune.ln_structured(model.conv, name="weight",amount=0.33,n=1,dim=-1)
print(model.conv.weight)

输出为:

python">tensor([[[[-0.3017,  0.1290, -0.0000],[ 0.2107,  0.1799,  0.0000],[ 0.1887, -0.0527,  0.0000]]],[[[ 0.0799,  0.1399, -0.0000],[ 0.2013, -0.0352, -0.0000],[-0.1724, -0.3094, -0.0000]]],[[[ 0.0419,  0.2224, -0.0000],[ 0.2084,  0.0543,  0.0000],[ 0.1493,  0.2011,  0.0000]]]], grad_fn=<MulBackward0>)

更改dim也是同样的效果:

  • dim=0
python">prune.ln_structured(model.conv, name="weight",amount=0.33,n=1,dim=0)
print(model.conv.weight)

输出为:

python">tensor([[[[-0.3017,  0.1290, -0.2468],[ 0.2107,  0.1799,  0.1923],[ 0.1887, -0.0527,  0.1403]]],[[[ 0.0799,  0.1399, -0.0084],[ 0.2013, -0.0352, -0.1027],[-0.1724, -0.3094, -0.2382]]],[[[ 0.0000,  0.0000, -0.0000],[ 0.0000,  0.0000,  0.0000],[ 0.0000,  0.0000,  0.0000]]]], grad_fn=<MulBackward0>)
  • dim=2
python">prune.ln_structured(model.conv, name="weight",amount=0.33,n=1,dim=2)
print(model.conv.weight)

输出为:

python">tensor([[[[-0.3017,  0.1290, -0.2468],[ 0.0000,  0.0000,  0.0000],[ 0.1887, -0.0527,  0.1403]]],[[[ 0.0799,  0.1399, -0.0084],[ 0.0000, -0.0000, -0.0000],[-0.1724, -0.3094, -0.2382]]],[[[ 0.0419,  0.2224, -0.1558],[ 0.0000,  0.0000,  0.0000],[ 0.1493,  0.2011,  0.0310]]]], grad_fn=<MulBackward0>)

3. 总结

至此,prune中的非结构化剪枝和结构化剪枝介绍完毕!


http://www.ppmy.cn/server/51581.html

相关文章

Java启动jar设置内存分配详解

在微服务架构越来越盛行的情况下&#xff0c;我们通常一个系统都会拆成很多个小的服务&#xff0c;但是最终部署的时候又因为没有那么多服务器只能把多个服务部署在同一台服务器上&#xff0c;这个时候问题就来了&#xff0c;服务器内存不够&#xff0c;这个时候我们就需要对每…

Lua 面向对象编程

Lua 面向对象编程 Lua 是一种轻量级的编程语言,通常用于嵌入应用程序中,提供灵活的扩展和定制功能。尽管 Lua 本身是一种过程式语言,但它提供了强大的元机制,允许开发者实现面向对象的编程范式。本文将探讨 Lua 中的面向对象编程(OOP)概念、实现方式以及最佳实践。 面向…

大型语言模型(LLM)和多模态大型语言模型(MLLM)的越狱攻击

随着大型语言模型&#xff08;LLMs&#xff09;的快速发展&#xff0c;它们在各种任务上表现出了卓越的性能&#xff0c;有效地遵循指令以满足多样化的用户需求。然而&#xff0c;随着这些模型遵循指令的能力不断提升&#xff0c;它们也越来越成为对抗性攻击的目标&#xff0c;…

STM32G070休眠例程-STOP模式

一、简介 主控是STM32G070&#xff0c;在低功耗休眠模式时采用Stop0模式&#xff0c;通过外部中断唤醒&#xff0c;唤醒之后&#xff0c;即可开启对应的功能输出&#xff0c;另外程序中设计有看门狗8S溢出&#xff0c;这个采用RTC定时6S周期唤醒去喂狗&#xff0c;RTC唤醒喂狗的…

redis面试总结

redis的数据类型&#xff1f; string字符串&#xff1a;类似于java中Map<String,String>。存储字符串、JSON数据、验证码等。 Hash字典&#xff1a;类似java中Map<String, Map<Spring,String>>。比较适合存储对象数据。 List列表&#xff1a;类似java中Ma…

Java Matcher类方法深度剖析:查找和匹配、索引方法

1. 引言 在Java中,正则表达式是处理字符串的强大工具,而java.util.regex包中的Matcher类则是实现这一功能的核心。对于Java工程师而言,熟练掌握Matcher类的使用方法,无疑能够极大地提升字符串处理的效率和准确性。本文将对Matcher类的方法进行深度讲解,并按照查找和匹配方…

代码随想录算法训练营DAY39|62.不同路径、63. 不同路径 II

忙。。后两题先跳过 62.不同路径 题目链接&#xff1a;62.不同路径 class Solution(object):def uniquePaths(self, m, n):""":type m: int:type n: int:rtype: int"""dp[[0 for a in range(n)] for b in range(m)]print(dp)dp[0][0]1for i i…

邂逅Three.js探秘图形世界之美

可能了解过three.js等大型的3D 图形库同学都知道啊&#xff0c;学习3D技术都需要有图形学、线性代数、webgl等基础知识&#xff0c;以前读书学的线性代数足够扎实的话听这节课也会更容易理解&#xff0c;这是shader课程&#xff0c;希望能帮助你理解着色器&#xff0c;也面向第…