知识蒸馏:如何让小模型继承大模型的智慧,提升效率不牺牲效果

news/2025/2/21 7:28:46/

我们今天给大家分享一篇知识蒸馏的工作。

我们知道,现如今模型各种各样,效果方面屡创新高。但是,有的时候,效果提升会有效率的牺牲。那么知识蒸馏,就是能够让保证效果的同时,提升效率。这篇文章将给大家来介绍一篇知识蒸馏相关的知识。

原工作链接:https://arxiv.org/pdf/1503.02531

image.png

谷歌的三位大佬所著。

一、研究背景:模型训练与部署的矛盾困境

在机器学习的实际应用中,训练模型和部署模型就像两个性格迥异的 “小伙伴”,有着不同的需求。训练模型的时候,我们希望它能从大量数据里 “挖” 出有用的信息,哪怕计算量再大、耗时再久也没关系,就像一个耐心的矿工,慢慢挖掘宝藏。比如在语音识别和物体识别这些任务里,训练模型要处理海量、高度冗余的数据集,从中提取出关键的特征和规律。

但是,当模型训练好要部署到实际场景中时,情况就大不一样了。这时候,对模型的延迟和计算资源的要求变得非常严格。想象一下,你用手机语音搜索的时候,如果模型响应很慢,或者特别耗电,你肯定会觉得体验很差。所以,部署的模型需要能快速给出结果,还不能占用太多资源。

就像昆虫有幼虫和成虫两种形态,分别适应不同的生存需求一样。在机器学习里,我们也应该根据训练和部署的不同需求,采用不同的模型策略。有时候,为了从数据里更好地提取结构,我们可以先训练一个 “笨重” 但强大的模型,这个模型可以是多个单独训练的模型组成的集合,也可以是一个用很强的正则化方法(比如 dropout)训练出来的大模型。然后,再想办法把这个 “笨重” 模型的知识,转移到一个更小巧、适合部署的模型里,这就是知识蒸馏技术诞生的初衷。

而且,传统上大家总觉得模型的知识就体现在它学习到的参数值里,这种想法限制了我们对模型知识的理解。其实,从更抽象的角度看,模型的知识是一种从输入向量到输出向量的映射关系。那些 “笨重” 的模型在学习区分大量类别时,除了关注正确答案的概率,它们对错误答案分配的概率也包含着很多信息,这些信息能告诉我们模型是怎么进行泛化的。比如一张宝马车的图片,被误认成垃圾车的概率虽然很小,但比误认成胡萝卜的概率要大得多,这就体现了模型对不同类别之间相似性的理解。

另外,我们训练模型的时候,通常是希望它在新数据上也能表现得好,也就是泛化能力强。但实际上,我们训练模型用的目标函数往往只是让它在训练数据上表现好,因为我们不知道怎么让它更好地泛化。 不过,在知识蒸馏的过程中,我们可以让小模型学习大模型的泛化方式,这样小模型在测试数据上的表现就可能比普通训练的小模型要好。

二、知识蒸馏的方法:让小模型 “继承” 大模型的智慧

知识蒸馏的核心思想,就是把大模型(也就是前面说的 “笨重” 模型)的知识转移到小模型里。具体是怎么做的呢?这就要说到 “软目标”(soft targets)这个关键概念了。

image.png

(一)什么是软目标

神经网络通常会用一个 “softmax” 输出层来计算每个类别的概率。公式qi=exp(zi/T)∑jexp(zj/T)qi​=∑j​exp(zj​/T)exp(zi​/T)​,这里的TT就是温度参数,一般情况下TT会设为 1 。当TT取值比较低的时候,概率分布会比较 “硬”,也就是说模型对某个类别的预测会比较确定;而当TT取值高的时候,概率分布就会变得更 “软”,模型对各个类别的预测概率会更均匀一些,这样就能包含更多信息。

知识蒸馏里,我们把大模型在高温度TT下产生的概率分布作为 “软目标”,用来训练小模型。比如说,大模型对一张图片的预测,原本在T=1T=1时,可能对某个类别的概率预测是 0.9,其他类别概率都很小。但当TT升高后,各个类别的概率会变得更分散,这些概率里就包含了大模型对不同类别之间关系的理解,比如哪些类别容易混淆等信息。小模型通过学习这些软目标,就能学到大模型的一些 “思考方式”。

(二)知识蒸馏的具体过程

image.png

在最简单的知识蒸馏形式里,我们会用大模型在高温度TT下对一个 “转移集”(可以是原始训练集,也可以是单独的一个数据集)生成软目标分布。然后,用这些软目标来训练小模型,训练小模型的时候也用同样的高温度TT。等小模型训练好之后,再把温度调回 1 来进行正常的预测

如果转移集里的数据有正确标签,我们还可以进一步改进训练方法。可以用两个不同的目标函数的加权平均来训练小模型。第一个目标函数是小模型和软目标之间的交叉熵,计算这个交叉熵的时候,小模型的 softmax 用和生成软目标时一样的高温度TT;第二个目标函数是小模型和正确标签之间的交叉熵,这个计算用温度T=1T=1 。一般来说,第二个目标函数的权重会设置得比较低。

为什么要这么做呢?因为软目标产生的梯度大小和1/T21/T2成正比,所以在同时使用硬目标(正确标签)和软目标的时候,要把软目标产生的梯度乘以T2T2,这样才能保证当我们调整蒸馏温度TT来实验不同的超参数时,硬目标和软目标的相对贡献大致保持不变。

image.png

(三)匹配 logits 是知识蒸馏的特殊情况

这里还有个有趣的发现,匹配大模型和小模型的 logits(也就是 softmax 层的输入)其实是知识蒸馏的一种特殊情况。在训练小模型时,每个样本都会对小模型的每个 logit 产生一个交叉熵梯度。当蒸馏温度TT比较高,并且和 logits 的大小相比可以忽略不计时,经过一些数学推导可以发现,蒸馏就相当于最小化小模型和大模型 logits 之差的平方的一半,前提是每个转移样本的 logits 都已经单独进行了零均值化处理。

不过,在较低温度下,蒸馏会对那些比平均值小很多的 logits 关注较少。这既有好处也有坏处,好处是这些 logits 可能在训练大模型时受损失函数的约束很小,所以可能有很多噪声,忽略它们能减少干扰;但坏处是它们也可能包含了大模型学到的有用知识。到底哪个影响更大,还得通过实际实验来判断。研究发现,当小模型太小,没办法完全学习大模型的所有知识时,中等温度的蒸馏效果最好,这说明适当忽略那些非常小的 logits 可能是有帮助的。

三、实验验证:知识蒸馏真的有效吗?

为了验证知识蒸馏的效果,研究人员进行了一系列实验,涵盖了图像识别和语音识别等领域,下面我们一起来看看这些有趣的实验结果。

(一)MNIST 数据集上的实验

点击知识蒸馏:如何让小模型继承大模型的智慧,提升效率不牺牲效果查看全文。


http://www.ppmy.cn/news/1573824.html

相关文章

GAMES101-现代计算机图形学入门笔记

主讲老师:闫令琪,此处仅做个人笔记使用。如果我的分享对你有帮助,请记得点赞关注不迷路。 课程链接如下:GAMES101-现代计算机图形学入门-闫令琪_哔哩哔哩_bilibili 课程分为四部分:光栅化、几何、光线追踪、模拟 图形…

uniapp录制语音

给大家讲解瞎 录制语音 的功能,这部分主要涉及到以下几个步骤:开始录音、停止录音、播放录音的功能 1.开始录音 (startRecording 函数) 当用户点击 开始录音 按钮时,调用 startRecording 函数开始录音。录音通过 uni.getRecorderManager() …

实验流量统计设计

当我们需要统计实验中每个分支的实际进入次数时,如何设计一个高效、可靠且对业务影响最小的方案,成为了关键。以下是几种常见的流量统计方案的分析与实现设计 目标 不影响实际业务使用,不应该因为汇报错误,导致灰度、甚至实际业…

Qt中使用QPdfWriter类结合QPainter类绘制并输出PDF文件

一.类的介绍 1.QPdfWriter介绍 Qt中提供了一个直接可以处理PDF的类,这就是QPdfWriter类。 (1)PDF文件生成 支持创建新的PDF文件或覆盖已有文件,通过构造函数直接绑定文件路径或QFile对象; 默认生成矢量图形PDF&#…

【DeepSeek 系列】DeepSeekMoE

文章目录 Transformers 中的 MoETransformer语言模型通用的MoE架构 DeepSeekMoE架构细粒度专家分割共享专家隔离负载均衡考虑 模型预训练不同尺寸模型超参数概览DeepSeekMoE 2B训练数据基础设施超参数 DeepSeekMoE 16B训练数据超参数 DeepSeekMoE 16B 的对齐训练数据超参数 总结…

Unity教程(二十一)技能系统 基础部分

Unity开发2D类银河恶魔城游戏学习笔记 Unity教程(零)Unity和VS的使用相关内容 Unity教程(一)开始学习状态机 Unity教程(二)角色移动的实现 Unity教程(三)角色跳跃的实现 Unity教程&…

【Linux】序列化、守护进程、应用层协议HTTP、Cookie和Session

⭐️个人主页:小羊 ⭐️所属专栏:Linux 很荣幸您能阅读我的文章,诚请评论指点,欢迎欢迎 ~ 目录 1、序列化和反序列化2、守护进程2.1 什么是进程组?2.2 什么是会话? 3、应用层协议HTTP3.1 HTTP协议3.2 HT…

Mentalab Explore Pro:第三代移动 EEG 设备,开启便携式脑电研究新时代

Mentalab推出的Explore Pro是一款专为研究和工业领域设计的第三代移动脑电图(EEG)设备。它凭借很高的精度和小巧的尺寸,为脑电研究提供了新的可能性。这款设备凭借强大的功能和灵活的设计,成为研究人员在各种实验环境中重要的工具…