深度学习blog-剪枝和知识蒸馏

devtools/2025/1/17 17:08:54/

深度学习网络模型从卷积层到全连接层存在着大量冗余的参数,大量神经元激活值趋近于0,将这些神经元去除后可以表现出同样的模型表达能力,这种情况被称为过参数化。因此需要一些技术手段减少模型的复杂性,去除一些不重要的参数和连接,从而提高模型在推理阶段的效率,减少存储需求,同时可能还能够降低过拟合的风险。

常用的模型的压缩和轻量化加速技术有:

  1. 权重剪枝通过删除神经网络中冗余的权重来减少模型的复杂度和计算量。具体来说,可以通过设定一个阈值来判断权重的重要性,然后将不重要的权重设置为零或删除。
  2. 模型量化:将神经网络中的权重和激活值从浮点数转换为低精度的整数表示,从而减少模型的存储空间和计算量。
  3. 知识蒸馏(Knowledge Distillation):这是一种特殊的模型蒸馏技术,其中教师模型和学生模型具有相同的架构,但参数不同。通过让学生模型学习教师模型的输出,可以实现模型的压缩和加速。
  4. 知识提炼(Knowledge Carving):选择性地从教师模型中抽取部分子结构用于构建学生模型。
  5. 网络剪枝(Network Pruning):通过删除神经网络中冗余的神经元或连接来减少模型的复杂度和计算量。具体来说,可以通过设定一个阈值来判断神经元或连接的重要性,然后将不重要的神经元或连接删除。
  6. 低秩分解(Low-Rank Factorization):将神经网络中的权重矩阵分解为两个低秩矩阵的乘积,从而减少模型的存储空间和计算量。这种方法可以应用于卷积层和全连接层等不同类型的神经网络层。
  7. 结构搜索(Neural Architecture Search):通过自动搜索最优的神经网络结构来实现模型的压缩和加速。这种方法可以根据特定任务的需求来定制适合的神经网络结构。

剪枝(Pruning)深度学习和神经网络中常用的一种模型压缩技术。

1. 剪枝的背景
深度学习模型通常由大量的参数组成,尤其是在深层神经网络中,这些参数使得模型能力强大,但也导致计算和存储成本高。为了在工业应用中将模型部署到资源有限的设备上,剪枝成为了重要的研究方向。

模型剪枝主要分为结构化剪枝非结构化剪枝非结构化剪枝去除不重要的神经元,相应地,被剪除的神经元和其他神经元之间的连接在计算时会被忽略。由于剪枝后的模型通常很稀疏,并且破坏了原有模型的结构,所以这类方法被称为非结构化剪枝

2. 剪枝粒度分类
 

  •  细粒度剪枝(fine-grained):即对连接或者神经元进行剪枝,它是粒度最小的剪枝
  • 向量剪枝(vector-level):它相对于细粒度剪枝粒度更大,属于对卷积核内部(intra-kernel)的剪枝
  • 剪枝(kernel-level):即去除某个卷积核,它将丢弃对输入通道中对应计算通道的响应。
  • 滤波器剪枝(Filter-level):对整个卷积核组进行剪枝,会造成推理过程中输出特征通道数的改变。

3. 剪枝的流程
剪枝的基本流程通常包括以下几个步骤:

训练:首先对神经网络进行训练,直到达到满意的精度。
评估重要性:使用特定的标准(如权重的绝对值、梯度等)来评估每个参数或神经元的重要性。
剪枝:根据重要性评估结果,去除一些参数或神经元。
微调(Fine-tuning):对剪枝后的模型进行再训练,以恢复模型的性能。

4. 剪枝的优缺点
优点:
减少计算复杂度:剪枝后,模型推理速度更快。
降低存储需求:模型所需的存储空间减少。
提高模型泛化能力:可能减少过拟合,并提高在新数据上的表现。
缺点:
剪枝带来的性能损失:不当的剪枝可能导致模型精度下降。
额外的复杂性:剪枝和微调过程增加了模型训练的复杂性。


5. 实际应用
剪枝技术已广泛应用于移动设备、边缘计算和实时应用中,例如图像识别、自然语言处理等任务,很多现代深度学习框架(如TensorFlow、PyTorch)都有包含剪枝的相关工具和库。

知识蒸馏(Knowledge Distillation)

是一种模型压缩技术,旨在将大型深度学习模型(通常称为“教师模型”)中的知识转移到较小的模型(称为“学生模型”)中。这种技术在计算资源有限的环境下尤为重要,因为它可以提高推理速度并减少模型的存储需求。
知识蒸馏主要包括以下几个步骤:

训练教师模型:首先,训练一个大型性能良好的教师模型。此模型通常在大规模数据集上经过充分训练,能够非常有效地捕捉数据中的复杂模式。

生成软标签:教师模型在训练集上预测的输出称为“软标签”。软标签包含了每个类别的概率分布,相比于传统的硬标签(one-hot编码),它保留了更多的信息。

训练学生模型:使用教师模型的预测(软标签)来训练更小的学生模型。在训练过程中,学生模型将学习教师模型的预测分布,而不仅仅是目标类别。

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torch.nn import functional as F
from ultralytics import YOLO# 假设教师模型和学生模型已经定义
teacher_model = YOLO('yolo11n.pt').to(device)  
teacher_model.eval()  # 设置为评估模式
student_model = StudentModel().to(device)# 知识蒸馏损失
def distillation_loss(y_true, y_pred, teacher_output, T, alpha):"""distillation loss = alpha * cross_entropy_loss(y_true, y_pred) + (1 - alpha) * KL_divergence(softmax(teacher_output / T), softmax(y_pred / T))"""loss_ce = F.cross_entropy(y_pred, y_true)loss_kl = F.kl_div(F.log_softmax(y_pred / T, dim=1), F.softmax(teacher_output / T, dim=1), reduction='batchmean')return alpha * loss_ce + (1 - alpha) * loss_kl# 训练学生模型
optimizer = optim.Adam(student_model.parameters(), lr=1e-4)# train
for epoch in range(num_epochs):student_model.train()for images, labels in train_loader:optimizer.zero_grad()# 使用教师模型对训练数据生成软标签teacher_output = teacher_model(images)student_output = student_model(images)loss = distillation_loss(labels, student_output, teacher_output, T=2, alpha=0.7)loss.backward()optimizer. Step()

分类:

离线蒸馏,大型教师模型蒸馏前在训练样本训练;教师模型以logits或中间特征的形式提取知识,将其在蒸馏过程中指导学生模型的训练。教师的结构是预定义的,很少关注教师模型的结构及其与学生模型的关系。例如上面的蒸馏,使用预训练的权重作为教师模型。

在线蒸馏:教师模型和学生模型同步更新,而整个知识蒸馏框架都是端到端可训练的。

自蒸馏:教师和学生模型使用相同的网络,这可以看作是在线蒸馏的一个特例。

蒸馏算法,有基于注意力蒸馏,基于图蒸馏,基于生成对抗网络GAN蒸馏,量化蒸馏等等。


http://www.ppmy.cn/devtools/151316.html

相关文章

python mysql库的三个库mysqlclient mysql-connector-python pymysql如何选择,他们之间的区别

三者的区别 1. mysqlclient 特点: 是一个用于Python的MySQL数据库驱动程序,用于与MySQL数据库进行交互。 依赖于MySQL的本地库,因此在安装时需要确保系统上已安装了必要的依赖项,如libmysqlclient-dev等。 性能较好&#xff0c…

基于Netty+InfluxDB+MQTT+Spring Boot的物联网(IoT)项目实现方案

基于NettyInfluxDBMQTTSpring Boot的物联网(IoT)项目实现方案 引言 物联网(IoT)技术近年来发展迅速,广泛应用于智能城市、工业物联网、农业物联网等领域。本文将详细介绍如何使用Netty、InfluxDB、MQTT和Spring Boot…

【大数据】机器学习-----模型的评估方法

一、评估方法 留出法(Holdout Method): 将数据集划分为训练集和测试集两部分,通常按照一定比例(如 70% 训练集,30% 测试集)。训练集用于训练模型,测试集用于评估模型性能。优点&…

Vue.js组件开发-实现输入框与筛选逻辑

在Vue.js组件开发中,实现输入框与筛选逻辑通常涉及创建一个输入框组件,让用户能够输入搜索关键字,并根据这些关键字过滤一个数据列表。 步骤 ‌准备数据‌: 在Vue组件中,准备一个数据列表(通常是一个数组…

AI数字人小程序:解锁个性化智能交互体验

随着人工智能的快速发展,AI数字人迅速崛起,成为了人们日常生活、工作等领域中的重要力量,深受用户的青睐。AI数字人不仅适用于各个领域中,帮助大众高效完成工作等,还能够帮助企业实现数字化发展。目前,AI数…

一个可以把玩的针对WebSocket分段的处理方案

市场上各种高级语言的WebSocket Echo的测试方案不少,但找来找去,愣是没有一个现成的可以针对分段(fragmetation)处理的Echo服务端。分段处理在一些对实时性要求较高的场合非常重要,比如流媒体,实时监控等场…

如何选择合适的服务器?服务器租赁市场趋势分析

服务器租赁市场概览 服务器租赁 market可以分为两种类型:按小时、按月和按年,每种模式都有其特点和适用场景,按小时租赁是最经济实惠的选择,适用于短期需求;按月租赁则适合中长期使用;而按年租赁则是最灵活…

uniapp button 去除边框

在找去除边框的办法时试了好久 css里设置了 border: none; /* 去掉边框 */outline: none; /* 确保点击时不出现轮廓 */压根不行,按钮还是浮在页面上有明显轮廓 最后看到了大佬的文章 https://www.cnblogs.com/menxiaojin/p/13752916.html button::after{border: no…