关于Deit中的知识蒸馏(Knowledge Distillation)详解

news/2024/11/16 21:38:03/

文章目录

  • 1. 知识蒸馏的作用
  • 2. 知识蒸馏的一般步骤
    • 1. 准备数据集
    • 2. 训练教师模型
    • 3. 得到教师模型输出
    • 4. 准备学生模型
    • 5. 定义损失函数
    • 6. 进行知识蒸馏训练
    • 7. 调节温度参数
  • 3. Deit中选用的教师模型为什么是ConvNet?
  • 4. 软标签和硬标签
  • 5. 知识蒸馏在Deit代码中的体现

1. 知识蒸馏的作用

知识蒸馏是一种模型压缩技术,通过训练一个小的模型来近似一个大的模型,用于将一个大的模型(称为教师模型)的知识传递给一个小的模型(称为学生模型),从而提高学生模型的性能和准确度。在蒸馏过程中,大模型被称作教师模型(teacher model),小模型被称作学生模型(student model)。

2. 知识蒸馏的一般步骤

1. 准备数据集

准备一个用于训练的数据集。这个数据集一般教师模型已经预测的标签和相应的输入数据。

2. 训练教师模型

使用教师模型在准备好的模型上进行训练,使其达到一个较高的准确率。可以使用一个已经预训练好的模型,也可以从头开始训练模型。
教师模型是一个基于卷积神经网络(CNN)的预训练模型,如BiT-M或BiT-L。这些模型在大规模的数据集上进行了预训练,具有很强的泛化能力和鲁棒性。

3. 得到教师模型输出

使用训练好的教师模型对训练数据集的输入进行预测,得到教师模型的输出结果。这些结果将作为学生模型的目标标签。

4. 准备学生模型

准备一个较小的模型作为蒸馏模型。
学生模型是一个基于变换器(Transformer)的图像分类模型,如ViT。这些模型在小规模的数据集上进行了训练,具有较低的计算成本和内存占用

5. 定义损失函数

知识蒸馏的关键就是定义一个合适的损失函数,用于比较学生模型的输出和教师模型输出的区别。
蒸馏损失函数是二元交叉熵(Binary Cross Entropy),而不是KL散度(Kullback-Leibler Divergence)。这样做可以减少计算量和内存占用,同时也可以提高准确率

6. 进行知识蒸馏训练

用同样的数据集训练该学生模型,但是不仅使用真实标签作为目标,还使用教师模型的输出作为目标,从而让学生模型学习教师模型的软标签(即概率分布),计算损失函数并反向传播,直至模型收敛。
蒸馏目标是教师模型输出的硬标签(即最大概率对应的类别),而不是软标签。这样做可以避免教师模型的错误预测影响学生模型,或者增加训练难度从而提高泛化能力

7. 调节温度参数

在计算损失函数时,可以为教师模型和学生模型定义一个温度参数。这个参数可以控制教师模型的软目标相对于硬目标的重要性。

3. Deit中选用的教师模型为什么是ConvNet?

论文中也没有给出明确的解释,但是给出的推测是:卷积神经网络具有transformer所没有的inductive bias,所以经过卷积网络训练的教师网络的输出是带有inductive bias的特性的。
原文内容:

The fact that the convnet is a better teacher is probably due to the inductive bias inherited by the transformers through distillation

4. 软标签和硬标签

硬标签是指传统的“独热编码”形式的标签,即对于每个类别只有一个元素为1,其余元素为0。例如,对于一个10类的分类问题,硬标签可能是一个长度为10的向量,只有在正确的类别索引处有1,其他位置都为0。
软标签是一种概率分布的形式,它表示了教师模型对于每个类别的置信度。软标签是由教师模型的输出通过softmax函数计算而得到的。这意味着软标签中的每个元素都是在0到1之间的值,并且所有元素的和等于1。软标签可以看作是对每个类别的概率预测。

5. 知识蒸馏在Deit代码中的体现

1.在models.py文件中,定义了一个ViTDistilled类,用于构建学生模型。这个类继承了ViT类,但是在输出层增加了一个dist_token,用于接收教师模型的输出作为蒸馏目标。

2.在losses.py文件中,定义了一个DistillationLoss类,用于计算知识蒸馏的损失函数。这个类包括了二元交叉熵损失和注意力匹配损失两部分。

3.在main.py文件中,定义了一个DeiT类,用于封装学生模型和教师模型的训练和评估逻辑。这个类在初始化时会加载教师模型,并在训练时使用教师模型的输出作为蒸馏目标和正则化项。

4.在hubconf.py文件中,提供了一些预训练的学生模型和教师模型,可以直接加载或者下载使用。这些模型都是基于知识蒸馏方法训练得到的。


http://www.ppmy.cn/news/989068.html

相关文章

卷积的意义及其应用

卷积的意义及其应用 卷积的定义 我们将形如 ∫ − ∞ ∞ f ( τ ) g ( x − τ ) d τ \int^\infty_{-\infty} f(τ)g(x-τ)dτ ∫−∞∞​f(τ)g(x−τ)dτ 的式子称之为f(x)与g(x)的卷积记为 h ( x ) ( f ∗ g ) ( x ) h(x…

目前可以实现用手机操作水质自动采样器吗

利用自动采样器进行水样采集可以说节省很大的人力物力,但是有时为了采到更具代表性的水样,我们需要对沟渠、深井、排污口等特殊场景进行采样。像这些狭小的空间领域采样就有点困难,对现场工作人员就带来了一些难题。所以也需要一款可以在井下…

【Python数据分析】Python常用内置函数(一)

🎉欢迎来到Python专栏~Python常用内置函数(一) ☆* o(≧▽≦)o *☆嗨~我是小夏与酒🍹 ✨博客主页:小夏与酒的博客 🎈该系列文章专栏:Python学习专栏 文章作者技术和水平有限,如果文…

zookeeper集群节点替换方案

Zookeeper集群节点替换 任务描述 Kafka集群依赖的Zookeeper集群中节点替换 解决步骤 解决思路: 先将新节点加入zookeeper集群依次修改Kafka配置并重启依次修改zookeeper配置并重启(leader节点最后重启 实践操作: 1)操作新机…

JS前端读取本地上传的File文件对象内容(包括Base64、text、JSON、Blob、ArrayBuffer等类型文件)

读取base64图片File file2Base64Image(file, cb) {const reader new FileReader();reader.readAsDataURL(file);reader.onload function (e) {cb && cb(e.target.result);//即为base64结果}; }, 读取text、JSON文件File readText(file, { onloadend } {}) {const re…

Redis简介,设置redis内存大小,设置redis淘汰机制,查看内存占用情况,内存占用分析

为什么使用Redis缓存数据库 我们日常的开发,无非是对数据的处理。程序的定义也可以这样狭义的解释:算法数据。可见数据库是多么重要的工具。但是关系型数据库的读写能力在200-1000次/秒不等,服务器好点可能更多,这导致在高并发的…

draw up a plan

爱情是美好的,却不是唯一的。爱情只是属于个人化的感情。 推荐一篇关于爱情的美文: 在一个小镇上,有一家以制作精美巧克力而闻名的手工巧克力店,名叫“甜蜜之爱”。这家巧克力店是由一位名叫艾玛的年轻女性经营的,她对…

第19节:医学分析,分析轨迹数据以了解视频游戏对人类短期和长期记忆的影响

分析轨迹数据以了解视频游戏对人类短期和长期记忆的影响 1.简介 1.1背景和动机 这些数据是在一项关于探索新奇事物对儿童学习成功的影响的研究中记录的。 研究组由患有不同类型多动症的儿童和对照组组成。 在实验中,两组(患有多动症和对照组)都必须在不同的三天参加研究。…