模型减肥秘籍:模型压缩技术 模型剪枝

devtools/2024/11/16 8:19:54/

教程链接:模型减肥秘籍:模型压缩技术-课程详情 | Datawhale

相应的教程代码:datawhalechina/awesome-compression: 模型压缩的小白入门教程

模型剪枝介绍

模型剪枝是模型压缩中一种重要的技术,其基本思想是将模型中不重要的权重和分支裁剪掉,将网络结构稀疏化,进而得到参数量更小的模型,降低内存开销,使得推理速度更快,这对于需要在资源有限的设备上运行模型的应用来说尤为重要。然而,剪枝也可能导致模型性能的下降,因此需要在模型大小和性能之间找到一个平衡点。神经元在神经网络中的连接在数学上表示为权重矩阵,因此剪枝即是将权重矩阵中一部分元素变为零元素。这些剪枝后具有大量零元素的矩阵被称为稀疏矩阵,反之绝大部分元素非零的矩阵被称为稠密矩阵。剪枝过程如下图所示,目的是减去不重要的突触(Synapses)或神经元(Neurons)。

剪枝的划分

按照剪枝范围进行划分,剪枝分为局部剪枝全局剪枝

按照剪枝粒度进行划分,剪枝可分为细粒度剪枝(Fine-grained Pruning)基于模式的剪枝(Pattern-based Pruning)向量级剪枝(Vector-level Pruning)内核级剪枝(Kernel-level Pruning)通道级剪枝(Channel-level Pruning)

剪枝标准

怎么确定要减掉哪些呢?这里就涉及到剪枝的标准了。目前主流的剪枝标准有如下几种方法:

基于权重大小、 基于梯度大小、基于尺度、基于二阶。

剪枝时机总结

主要包含训练后剪枝(静态稀疏性)、训练时剪枝(动态稀疏)、训练前剪枝

一些概念上的细节大家可以结合教程进行阅读了解,下面给出相应的代码,相信知道了具体剪枝是如何实现的之后,能够加深对于剪枝的理解。

代码实现:

不同粒度的剪枝实现:

细粒度剪枝
def fine_grained_prune(tensor: torch.Tensor, threshold  : float) -> torch.Tensor:"""创建一个掩码张量,指示哪些权重不应被剪枝(应保持非零)。:param tensor: 输入张量,包含需要剪枝的权重。:param threshold: 阈值,用于判断权重的大小。:return: 剪枝后的张量。"""mask = torch.gt(tensor, threshold)"""torch.gt(tensor, threshold):这个函数会返回一个与tensor形状相同的布尔张量(掩码),其中每个元素的值为True或False,取决于对应的tensor元素是否大于threshold。"""tensor.mul_(mask)return tensor

使用mask掩码矩阵而不是for循环去遍历,可以加快运行速度。

基于模式的剪枝

这里以NVIDIA 4:2为例,创建一个patterns,如下图所示,由于是2:4,即从4个中取出2个置为0,可以算出一共有6种不同的模式;

然后将weight matrix变换成nx4的格式方便与pattern进行矩阵运算,运算后的结果为nx6的矩阵,在n的维度上进行argmax取得最大的索引(索引对应pattern),然后将索引对应的pattern值填充到mask中。

from itertools import permutationsdef reshape_1d(tensor, m):# 转换成列为m的格式,若不能整除m则填充0if tensor.shape[1] % m > 0:mat = torch.FloatTensor(tensor.shape[0], tensor.shape[1] + (m - tensor.shape[1] % m)).fill_(0)mat[:, : tensor.shape[1]] = tensorreturn mat.view(-1, m)else:return tensor.view(-1, m)def compute_valid_1d_patterns(m, n):# 创建一个长度为m的全零张量。patterns = torch.zeros(m)# 将前n个元素设置为1,形成一个包含n个1和m-n个0的模式。patterns[:n] = 1# 计算所有可能的排列#     permutations:用于生成所有可能的排列组合。# set:用于去重,确保返回的模式是唯一的。valid_patterns = torch.Tensor(list(set(permutations(patterns.tolist()))))return valid_patternsdef compute_mask(tensor, m, n):# 计算所有可能的模式patterns = compute_valid_1d_patterns(m,n)# 找到m:n最好的模式mask = torch.IntTensor(tensor.shape).fill_(1).view(-1,m)mat = reshape_1d(tensor, m)# torch.matmul(mat.abs(), patterns.t()):计算输入张量的绝对值与模式的点积,得到每个模式的得分。# torch.argmax(..., dim=1):找到每行的最大得分对应的索引,表示最佳模式。pmax = torch.argmax(torch.matmul(mat.abs(), patterns.t()), dim=1)mask[:] = patterns[pmax[:]]mask = mask.view(tensor.shape)return maskdef pattern_pruning(tensor, m, n):mask = compute_mask(weight, m, n)tensor.mul_(mask)return tensorpruned_weight = pattern_pruning(weight, 4, 2)
plot_tensor(pruned_weight, '剪枝后weight')
向量级别剪纸:
# 剪枝某个点所在的行与列
def vector_pruning(weight, point):row, col = pointprune_weight = weight.clone()prune_weight[row, :] = 0prune_weight[:, col] = 0return prune_weight
point = (1, 1)
prune_weight = vector_pruning(weight, point)
plot_tensor(prune_weight, '向量级剪枝后weight')

卷积核级别的剪枝
def prune_conv_layer(conv_layer, prune_method,title="", percentile=0.2, vis=True):prune_layer = conv_layer.clone()l2_norm = Nonemask = None# 计算每个kernel的L2范数l2_norm = torch.norm(prune_layer, p=2, dim=(-2, -1), keepdim=True)# 计算L2范数的分位数阈值threshold = torch.quantile(l2_norm, percentile)# 创建掩码,保留L2范数大于阈值的kernelmask = l2_norm > thresholdprune_layer = prune_layer * mask.float()visualize_tensor(prune_layer,title=prune_method)  # 使用PyTorch创建一个张量
tensor = torch.rand((3, 10, 4, 5))# 调用函数进行剪枝pruned_tensor = prune_conv_layer(tensor, 'Kernel级别剪枝', vis=True)

Filter级别的剪枝
def prune_conv_layer(conv_layer, prune_method,title="", percentile=0.2, vis=True):prune_layer = conv_layer.clone()l2_norm = Nonemask = None# 计算每个Filter的L2范数l2_norm = torch.norm(prune_layer, p=2, dim=(1, 2, 3), keepdim=True)threshold = torch.quantile(l2_norm, percentile)mask = l2_norm > thresholdprune_layer = prune_layer * mask.float()visualize_tensor(prune_layer,title=prune_method)  # 使用PyTorch创建一个张量
tensor = torch.rand((3, 10, 4, 5))# 调用函数进行剪枝pruned_tensor = prune_conv_layer(tensor, 'Filter级别剪枝', vis=True)

Channel级别的剪枝

def prune_conv_layer(conv_layer, prune_method,title="", percentile=0.2, vis=True):prune_layer = conv_layer.clone()l2_norm = Nonemask = None# 计算每个channel的L2范数l2_norm = torch.norm(prune_layer, p=2, dim=(0, 2, 3), keepdim=True)threshold = torch.quantile(l2_norm, percentile)mask = l2_norm > thresholdprune_layer = prune_layer * mask.float()visualize_tensor(prune_layer,title=prune_method)  # 使用PyTorch创建一个张量
tensor = torch.rand((3, 10, 4, 5))# 调用函数进行剪枝pruned_tensor = prune_conv_layer(tensor, 'Channel级别剪枝', vis=True)

后面三种的主要区别在于计算L2范数时,是在哪些维度上进行计算。

不同剪枝标准的实现:

基于L1权重大小的剪枝:
@torch.no_grad()
def prune_l1(weight, percentile=0.5):num_elements = weight.numel()# 计算值为0的数量(需要裁剪掉的权重数量)num_zeros = round(num_elements * percentile)# 计算weight的重要性importance = weight.abs()# 计算裁剪阈值# importance.view(-1) 将权重的重要性张量展平为一维。# kthvalue(num_zeros) 方法返回第 num_zeros 小的值(即裁剪阈值),这个值将用于确定哪些权重将被裁剪。threshold = importance.view(-1).kthvalue(num_zeros).values# 计算maskmask = torch.gt(importance, threshold)# 计算mask后的weightweight.mul_(mask)return weight
基于L2权重大小的剪枝:
@torch.no_grad()
def prune_l2(weight, percentile=0.5):num_elements = weight.numel()# 计算值为0的数量num_zeros = round(num_elements * percentile)# 计算weight的重要性(使用L2范数,即各元素的平方)importance = weight.pow(2)# 计算裁剪阈值threshold = importance.view(-1).kthvalue(num_zeros).values# 计算maskmask = torch.gt(importance, threshold)# 计算mask后的weightweight.mul_(mask)return weight
基于梯度幅度的剪枝
# 修剪局部模型权重,传入某一层的权重
@torch.no_grad()
def gradient_magnitude_pruning(weight, gradient, percentile=0.5):num_elements = weight.numel()# 计算值为0的数量num_zeros = round(num_elements * percentile)# 计算weight的重要性(使用L1范数)importance = gradient.abs()# 计算裁剪阈值threshold = importance.view(-1).kthvalue(num_zeros).values# 计算maskmask = torch.gt(importance, threshold)# 计算mask后的weightweight.mul_(mask)return weight
# 修剪整个模型的权重,传入整个模型
def gradient_magnitude_pruning(model, percentile):for name, param in model.named_parameters():if 'weight' in name:mask = torch.abs(gradients[name]) >= percentileparam.data *= mask.float()

注意这里修剪整个模型的权重的时候,是进行的绝对值的比较,即直接比较模型权重和percentile的大小,而不是计算出对应的threshold,个人觉得是因为不同层的权重之间可能本身存在大小差异,放在一起计算出相应位置的权重作为threshold并不是很合理。


http://www.ppmy.cn/devtools/134385.html

相关文章

Briefly unavailable for scheduled maintenance. Check back in a minute.

访问wordpress网站时出现“Briefly unavailable for scheduled maintenance. Check back in a minute.”时,不要着急,不要害怕,这不是什么多大的问题。这表明wordpress的程序或wordpress使用到的插件正在升级,这是在自动升级&…

〔 MySQL 〕数据类型

目录 1.数据类型分类 2 数值类型 2.1 tinyint类型 2.2 bit类型 2.3 小数类型 2.3.1 float 2.3.2 decimal 3 字符串类型 3.1 char 3.2 varchar 3.3 char和varchar比较 4 日期和时间类型 5 enum和set mysql表中建立属性列: 列名称,类型在后 n…

2024年11月15日Github流行趋势

项目名称:MinerU 项目维护者:myhloli, dt-yy, Focusshang, drunkpig, papayalove等项目介绍:一站式开源高质量数据提取工具,支持从PDF、网页和多格式电子书中提取数据。项目star数:15,059项目fork数:1,105 …

对话 OpenCV 之父 Gary Bradski:灾难性遗忘和持续学习是尚未解决的两大挑战 | Open AGI Forum

作者 | Annie Xu 采访、责编 | Eric Wang 出品丨GOSIM 开源创新汇 Gary Bradski,旺盛的好奇心、敢于冒险的勇气、独到的商业视角让他成为计算视觉、自动驾驶领域举重若轻的奠基者。 Gary 曾加入 Stanley 的团队,帮助其赢得 2005 年美国穿越沙漠 DA…

【Python进阶】自动化办公超能力:利用Python自动化Excel、Word任务

1、Python支持办公自动化的关键库介绍 import pandas as pd # 加载数据 df pd.read_csv(data.csv) # 对数据进行清洗和分析 df_cleaned df.dropna() # 删除缺失值 grouped_data df_cleaned.groupby(category).sum() # 按类别进行分组求和openpyxl, xlrd, xlwt, xlsxwrite…

C语言之MakeFile

Makefile 的引入是为解决多文件项目中手动编译繁琐易错、缺乏自动化构建、项目管理维护困难以及跨平台构建不便等问题&#xff0c;实现自动化、规范化的项目构建与管理 MakeFile 简单的来说,MakeFile就是编写编译命令的文件 文件编写格式 目标:依赖文件列表 <Tab>命令列表…

在 Oracle Linux 8.9 上安装Oracle Database 23ai 23.5

在 Oracle Linux 8.9 上安装Oracle Database 23ai 23.5 1. 安装 Oracle Database 23ai2. 连接 Oracle Database 23c3. 重启启动后&#xff0c;手动启动数据库4. 重启启动后&#xff0c;手动启动 Listener5. 手动启动 Pluggable Database6. 自动启动 Pluggable Database7. 设置开…

24.11.15 Vue3

let newJson new Proxy(myJson,{get(target,prop){console.log(在读取${prop}属性);return target[prop];},set(target,prop,val){console.log(在设置${prop}属性值为${val});if(prop"name"){document.getElementById("myTitle").innerHTML val;}if(prop…