教程链接:模型减肥秘籍:模型压缩技术-课程详情 | Datawhale
相应的教程代码:datawhalechina/awesome-compression: 模型压缩的小白入门教程
模型剪枝介绍
模型剪枝是模型压缩中一种重要的技术,其基本思想是将模型中不重要的权重和分支裁剪掉,将网络结构稀疏化,进而得到参数量更小的模型,降低内存开销,使得推理速度更快,这对于需要在资源有限的设备上运行模型的应用来说尤为重要。然而,剪枝也可能导致模型性能的下降,因此需要在模型大小和性能之间找到一个平衡点。神经元在神经网络中的连接在数学上表示为权重矩阵,因此剪枝即是将权重矩阵中一部分元素变为零元素。这些剪枝后具有大量零元素的矩阵被称为稀疏矩阵,反之绝大部分元素非零的矩阵被称为稠密矩阵。剪枝过程如下图所示,目的是减去不重要的突触(Synapses)或神经元(Neurons)。
剪枝的划分
按照剪枝粒度进行划分,剪枝可分为细粒度剪枝(Fine-grained Pruning)、基于模式的剪枝(Pattern-based Pruning)、向量级剪枝(Vector-level Pruning)、内核级剪枝(Kernel-level Pruning)与通道级剪枝(Channel-level Pruning)。
剪枝标准
怎么确定要减掉哪些呢?这里就涉及到剪枝的标准了。目前主流的剪枝标准有如下几种方法:
基于权重大小、 基于梯度大小、基于尺度、基于二阶。
剪枝时机总结
主要包含训练后剪枝(静态稀疏性)、训练时剪枝(动态稀疏)、训练前剪枝。
一些概念上的细节大家可以结合教程进行阅读了解,下面给出相应的代码,相信知道了具体剪枝是如何实现的之后,能够加深对于剪枝的理解。
代码实现:
不同粒度的剪枝实现:
细粒度剪枝:
def fine_grained_prune(tensor: torch.Tensor, threshold : float) -> torch.Tensor:"""创建一个掩码张量,指示哪些权重不应被剪枝(应保持非零)。:param tensor: 输入张量,包含需要剪枝的权重。:param threshold: 阈值,用于判断权重的大小。:return: 剪枝后的张量。"""mask = torch.gt(tensor, threshold)"""torch.gt(tensor, threshold):这个函数会返回一个与tensor形状相同的布尔张量(掩码),其中每个元素的值为True或False,取决于对应的tensor元素是否大于threshold。"""tensor.mul_(mask)return tensor
使用mask掩码矩阵而不是for循环去遍历,可以加快运行速度。
基于模式的剪枝:
这里以NVIDIA 4:2为例,创建一个patterns,如下图所示,由于是2:4,即从4个中取出2个置为0,可以算出一共有6种不同的模式;
然后将weight matrix变换成nx4的格式方便与pattern进行矩阵运算,运算后的结果为nx6的矩阵,在n的维度上进行argmax取得最大的索引(索引对应pattern),然后将索引对应的pattern值填充到mask中。
from itertools import permutationsdef reshape_1d(tensor, m):# 转换成列为m的格式,若不能整除m则填充0if tensor.shape[1] % m > 0:mat = torch.FloatTensor(tensor.shape[0], tensor.shape[1] + (m - tensor.shape[1] % m)).fill_(0)mat[:, : tensor.shape[1]] = tensorreturn mat.view(-1, m)else:return tensor.view(-1, m)def compute_valid_1d_patterns(m, n):# 创建一个长度为m的全零张量。patterns = torch.zeros(m)# 将前n个元素设置为1,形成一个包含n个1和m-n个0的模式。patterns[:n] = 1# 计算所有可能的排列# permutations:用于生成所有可能的排列组合。# set:用于去重,确保返回的模式是唯一的。valid_patterns = torch.Tensor(list(set(permutations(patterns.tolist()))))return valid_patternsdef compute_mask(tensor, m, n):# 计算所有可能的模式patterns = compute_valid_1d_patterns(m,n)# 找到m:n最好的模式mask = torch.IntTensor(tensor.shape).fill_(1).view(-1,m)mat = reshape_1d(tensor, m)# torch.matmul(mat.abs(), patterns.t()):计算输入张量的绝对值与模式的点积,得到每个模式的得分。# torch.argmax(..., dim=1):找到每行的最大得分对应的索引,表示最佳模式。pmax = torch.argmax(torch.matmul(mat.abs(), patterns.t()), dim=1)mask[:] = patterns[pmax[:]]mask = mask.view(tensor.shape)return maskdef pattern_pruning(tensor, m, n):mask = compute_mask(weight, m, n)tensor.mul_(mask)return tensorpruned_weight = pattern_pruning(weight, 4, 2)
plot_tensor(pruned_weight, '剪枝后weight')
向量级别剪纸:
# 剪枝某个点所在的行与列
def vector_pruning(weight, point):row, col = pointprune_weight = weight.clone()prune_weight[row, :] = 0prune_weight[:, col] = 0return prune_weight
point = (1, 1)
prune_weight = vector_pruning(weight, point)
plot_tensor(prune_weight, '向量级剪枝后weight')
卷积核级别的剪枝:
def prune_conv_layer(conv_layer, prune_method,title="", percentile=0.2, vis=True):prune_layer = conv_layer.clone()l2_norm = Nonemask = None# 计算每个kernel的L2范数l2_norm = torch.norm(prune_layer, p=2, dim=(-2, -1), keepdim=True)# 计算L2范数的分位数阈值threshold = torch.quantile(l2_norm, percentile)# 创建掩码,保留L2范数大于阈值的kernelmask = l2_norm > thresholdprune_layer = prune_layer * mask.float()visualize_tensor(prune_layer,title=prune_method) # 使用PyTorch创建一个张量
tensor = torch.rand((3, 10, 4, 5))# 调用函数进行剪枝pruned_tensor = prune_conv_layer(tensor, 'Kernel级别剪枝', vis=True)
Filter级别的剪枝:
def prune_conv_layer(conv_layer, prune_method,title="", percentile=0.2, vis=True):prune_layer = conv_layer.clone()l2_norm = Nonemask = None# 计算每个Filter的L2范数l2_norm = torch.norm(prune_layer, p=2, dim=(1, 2, 3), keepdim=True)threshold = torch.quantile(l2_norm, percentile)mask = l2_norm > thresholdprune_layer = prune_layer * mask.float()visualize_tensor(prune_layer,title=prune_method) # 使用PyTorch创建一个张量
tensor = torch.rand((3, 10, 4, 5))# 调用函数进行剪枝pruned_tensor = prune_conv_layer(tensor, 'Filter级别剪枝', vis=True)
Channel级别的剪枝:
def prune_conv_layer(conv_layer, prune_method,title="", percentile=0.2, vis=True):prune_layer = conv_layer.clone()l2_norm = Nonemask = None# 计算每个channel的L2范数l2_norm = torch.norm(prune_layer, p=2, dim=(0, 2, 3), keepdim=True)threshold = torch.quantile(l2_norm, percentile)mask = l2_norm > thresholdprune_layer = prune_layer * mask.float()visualize_tensor(prune_layer,title=prune_method) # 使用PyTorch创建一个张量
tensor = torch.rand((3, 10, 4, 5))# 调用函数进行剪枝pruned_tensor = prune_conv_layer(tensor, 'Channel级别剪枝', vis=True)
后面三种的主要区别在于计算L2范数时,是在哪些维度上进行计算。
不同剪枝标准的实现:
基于L1权重大小的剪枝:
@torch.no_grad()
def prune_l1(weight, percentile=0.5):num_elements = weight.numel()# 计算值为0的数量(需要裁剪掉的权重数量)num_zeros = round(num_elements * percentile)# 计算weight的重要性importance = weight.abs()# 计算裁剪阈值# importance.view(-1) 将权重的重要性张量展平为一维。# kthvalue(num_zeros) 方法返回第 num_zeros 小的值(即裁剪阈值),这个值将用于确定哪些权重将被裁剪。threshold = importance.view(-1).kthvalue(num_zeros).values# 计算maskmask = torch.gt(importance, threshold)# 计算mask后的weightweight.mul_(mask)return weight
基于L2权重大小的剪枝:
@torch.no_grad()
def prune_l2(weight, percentile=0.5):num_elements = weight.numel()# 计算值为0的数量num_zeros = round(num_elements * percentile)# 计算weight的重要性(使用L2范数,即各元素的平方)importance = weight.pow(2)# 计算裁剪阈值threshold = importance.view(-1).kthvalue(num_zeros).values# 计算maskmask = torch.gt(importance, threshold)# 计算mask后的weightweight.mul_(mask)return weight
基于梯度幅度的剪枝:
# 修剪局部模型权重,传入某一层的权重
@torch.no_grad()
def gradient_magnitude_pruning(weight, gradient, percentile=0.5):num_elements = weight.numel()# 计算值为0的数量num_zeros = round(num_elements * percentile)# 计算weight的重要性(使用L1范数)importance = gradient.abs()# 计算裁剪阈值threshold = importance.view(-1).kthvalue(num_zeros).values# 计算maskmask = torch.gt(importance, threshold)# 计算mask后的weightweight.mul_(mask)return weight
# 修剪整个模型的权重,传入整个模型
def gradient_magnitude_pruning(model, percentile):for name, param in model.named_parameters():if 'weight' in name:mask = torch.abs(gradients[name]) >= percentileparam.data *= mask.float()
注意这里修剪整个模型的权重的时候,是进行的绝对值的比较,即直接比较模型权重和percentile的大小,而不是计算出对应的threshold,个人觉得是因为不同层的权重之间可能本身存在大小差异,放在一起计算出相应位置的权重作为threshold并不是很合理。