【知识蒸馏】多任务模型 logit-based 知识蒸馏实战

news/2024/9/24 4:22:12/

一、什么是逻辑(logit)知识蒸馏

Feature-based蒸馏原理是知识蒸馏中的一种重要方法,其关键在于利用教师模型的隐藏层特征来指导学生模型的学习过程。这种蒸馏方式旨在使学生模型能够学习到教师模型在特征提取和表示方面的能力,从而提升其性能。

具体来说,Feature-based蒸馏通过比较教师模型和学生模型在某一或多个隐藏层的特征表示来实现知识的迁移。在训练过程中,教师模型的隐藏层特征被提取出来,并作为监督信号来指导学生模型相应层的特征学习。通过优化两者在特征层面的差异(如使用均方误差、余弦相似度等作为损失函数),可以使学生模型逐渐逼近教师模型的特征表示能力。

这种蒸馏方式有几个显著的优势。首先,它充分利用了教师模型在特征提取方面的优势,帮助学生模型学习到更具判别性的特征表示。其次,通过比较特征层面的差异,可以更加细致地指导学生模型的学习过程,使其在保持较高性能的同时减小模型复杂度。最后,Feature-based蒸馏可以与其他蒸馏方式相结合,形成更为复杂的蒸馏策略,以进一步提升模型性能。

需要注意的是,在选择进行Feature-based蒸馏的隐藏层时,需要谨慎考虑。不同层的特征具有不同的语义信息和抽象程度,因此选择合适的层进行蒸馏对于最终效果至关重要。此外,蒸馏过程中的损失函数和权重设置也需要根据具体任务和数据集进行调整。

综上所述,Feature-based蒸馏原理是通过利用教师模型的隐藏层特征来指导学生模型的学习过程,从而实现知识的迁移和模型性能的提升。这种方法在深度学习领域具有广泛的应用前景,尤其在需要提高模型特征提取能力的场景中表现出色。

二、如何进行多任务模型的知识蒸馏

(1)加载学生和教师模型
(2)定义分割蒸馏损失,定义检测蒸馏损失
(3)计算分割蒸馏损失,计算检测蒸馏损失
(4)计算学生模型的分割,检测损失
(5)计算总损失,反向传播

三、实现代码

(1)加载学生和教师模型

# 学生模型
model = torch.load(args.student_model, map_location=device)
# 教师模型
teacher_model = YourModel(task="multi")
teacher_model.load_state_dict(torch.load(args.teacher_model, map_location=device))

(2)定义分割蒸馏损失,定义检测蒸馏损失
分割损失,参考:【知识蒸馏】语义分割模型逻辑蒸馏实战,对剪枝的模型进行蒸馏训练

# ------------ seg logit distill loss -------------#
def seg_logit_distill_loss(t_pred, s_pred, tempature = 2):KD = nn.KLDivLoss(reduction='mean')t_p = F.softmax(t_pred / tempature, dim=1)s_p = F.log_softmax(s_pred / tempature, dim=1)loss = KD(s_p, t_p) * (tempature ** 2)return loss

检测损失,参考:【知识蒸馏】yolov5逻辑蒸馏和特征蒸馏实战

# ------------ det logit distill loss -------------#
def det_logit_distill_loss(t_pred,s_pred,tempature=1):L2 = nn.MSELoss(reduction="none")t_lobj = L2(s_pred[..., 4], t_pred[..., 4]).mean()t_lBox = L2(s_pred[..., :4], t_pred[..., :4]).mean()t_lcls = L2(s_pred[..., 5:], t_pred[..., 5:]).mean()return (t_lobj + t_lBox + t_lcls) * tempature

(3)计算分割蒸馏loss,计算检测蒸馏损失

with torch.no_grad():teacher_outputs = teacher_model(images)
# 分割蒸馏loss
teacher_seg_output = teacher_outputs.get("seg")
student_seg_output = predictions.get("seg")
seg_soft_loss = seg_logit_distill_loss(teacher_seg_output, student_seg_output)
# 检测蒸馏loss
teacher_det_output = teacher_outputs.get("det")
student_det_output = predictions.get("det")
det_soft_loss = det_logit_distill_loss(teacher_det_output, student_det_output)

(4)计算学生模型的分割,检测损失

det_loss = calc_det_loss(...)
seg_loss = CE_Loss(...)

(5)计算总损失,反向传播

seg_distill_loss = seg_loss * (1 - seg_alpha) + seg_soft_loss * seg_alpha
det_distill_loss = det_loss * (1 - det_alpha) + det_soft_loss * det_alpha
loss = det_distill_loss * Ratio_det + seg_distill_loss * Ratio_seg
loss.backward()

http://www.ppmy.cn/news/1463739.html

相关文章

前后端编程语言和运行环境的理解

我已重新检查了我的回答,并确保信息的准确性。以下是常用的编程语言,以及它们通常用于前端或后端开发,以及相应的框架和运行环境: 前端开发 JavaScript 框架:React, Angular, Vue.js, Ember.js, Backbone.js运行环境:Web 浏览器HTML (HyperText Markup Language) 不是编…

Linux--构建进程池

目录 1.进程池 1.1.我们先完成第一步,创建子进程和信道 1.2. 通过channel控制,发送任务 1.3回收管道和子进程 1.4进行测试 1.5完整代码 1.进程池 进程池其产生原因主要是为了优化大量任务需要多进程完成时频繁创建和删除进程所带来的资源消耗&#…

查询指定会话免打扰

查询指定用户(requestId) 为指定会话(targetId)的设置的免打扰状态。 提示 该设置为用户级别设置。对应的设置接口详见设置指定会话免打扰。 请求方法 POST: https://数据中心域名/conversation/notification/get.json 频率限…

白酒:不同产地白酒的风格特点与比较

云仓酒庄豪迈白酒,作为中国白酒的一部分,其风格特点深受产区的影响。不同产地的白酒,由于自然环境、酿造工艺等因素的差异,形成了各自与众不同的风味和特点。下面让云仓酒庄豪迈白酒来比较一下不同产地白酒的风格特点。 首先&…

VPN的详细理解

VPN(Virtual Private Network,虚拟私人网络)是一种在公共网络上建立加密通道的技术,通过这种技术可以使远程用户访问公司内部网络资源时,实现安全的连接和数据传输。以下是对VPN的详细介绍: 选择代理浏览器…

【夏之以寒-Kafka专栏 01】Kafka的消息是采用Pull模式还是Push模式?

作者名称:夏之以寒 作者简介:专注于Java和大数据领域,致力于探索技术的边界,分享前沿的实践和洞见 文章专栏:夏之以寒-kafka专栏 专栏介绍:本专栏旨在以浅显易懂的方式介绍Kafka的基本概念、核心组件和使用…

【文生漫画系统】小说推文快速生成漫画短视频,搭建一款属于自己的系统,在使用的同时又能运营。

当前热门小说推文的推广方式,又更新了。 从传统的解压视频或者跑酷视频,到现在的漫画形式。涨粉的速度是非常快速的。 所以在做小说推广项目的可以了解一下这款系统。 只需三步就可以生成漫画短视频了 一、文生漫画短视频怎么生成? 二、系…

832. 翻转图像 - 力扣

1. 题目 给定一个 n x n 的二进制矩阵 image ,先 水平 翻转图像,然后 反转 图像并返回 结果 。 水平翻转图片就是将图片的每一行都进行翻转,即逆序。 例如,水平翻转 [1,1,0] 的结果是 [0,1,1]。 反转图片的意思是图片中的 0 全部被…