文章目录
- 1. 知识蒸馏的作用
- 2. 知识蒸馏的一般步骤
- 1. 准备数据集
- 2. 训练教师模型
- 3. 得到教师模型输出
- 4. 准备学生模型
- 5. 定义损失函数
- 6. 进行知识蒸馏训练
- 7. 调节温度参数
- 3. Deit中选用的教师模型为什么是ConvNet?
- 4. 软标签和硬标签
- 5. 知识蒸馏在Deit代码中的体现
1. 知识蒸馏的作用
知识蒸馏是一种模型压缩技术,通过训练一个小的模型来近似一个大的模型,用于将一个大的模型(称为教师模型)的知识传递给一个小的模型(称为学生模型),从而提高学生模型的性能和准确度。在蒸馏过程中,大模型被称作教师模型(teacher model),小模型被称作学生模型(student model)。
2. 知识蒸馏的一般步骤
1. 准备数据集
准备一个用于训练的数据集。这个数据集一般教师模型已经预测的标签和相应的输入数据。
2. 训练教师模型
使用教师模型在准备好的模型上进行训练,使其达到一个较高的准确率。可以使用一个已经预训练好的模型,也可以从头开始训练模型。
教师模型是一个基于卷积神经网络(CNN)的预训练模型,如BiT-M或BiT-L。这些模型在大规模的数据集上进行了预训练,具有很强的泛化能力和鲁棒性。
3. 得到教师模型输出
使用训练好的教师模型对训练数据集的输入进行预测,得到教师模型的输出结果。这些结果将作为学生模型的目标标签。
4. 准备学生模型
准备一个较小的模型作为蒸馏模型。
学生模型是一个基于变换器(Transformer)的图像分类模型,如ViT。这些模型在小规模的数据集上进行了训练,具有较低的计算成本和内存占用
5. 定义损失函数
知识蒸馏的关键就是定义一个合适的损失函数,用于比较学生模型的输出和教师模型输出的区别。
蒸馏损失函数是二元交叉熵(Binary Cross Entropy),而不是KL散度(Kullback-Leibler Divergence)。这样做可以减少计算量和内存占用,同时也可以提高准确率
6. 进行知识蒸馏训练
用同样的数据集训练该学生模型,但是不仅使用真实标签作为目标,还使用教师模型的输出作为目标,从而让学生模型学习教师模型的软标签(即概率分布),计算损失函数并反向传播,直至模型收敛。
蒸馏目标是教师模型输出的硬标签(即最大概率对应的类别),而不是软标签。这样做可以避免教师模型的错误预测影响学生模型,或者增加训练难度从而提高泛化能力
7. 调节温度参数
在计算损失函数时,可以为教师模型和学生模型定义一个温度参数。这个参数可以控制教师模型的软目标相对于硬目标的重要性。
3. Deit中选用的教师模型为什么是ConvNet?
论文中也没有给出明确的解释,但是给出的推测是:卷积神经网络具有transformer所没有的inductive bias,所以经过卷积网络训练的教师网络的输出是带有inductive bias的特性的。
原文内容:
The fact that the convnet is a better teacher is probably due to the inductive bias inherited by the transformers through distillation
4. 软标签和硬标签
硬标签是指传统的“独热编码”形式的标签,即对于每个类别只有一个元素为1,其余元素为0。例如,对于一个10类的分类问题,硬标签可能是一个长度为10的向量,只有在正确的类别索引处有1,其他位置都为0。
软标签是一种概率分布的形式,它表示了教师模型对于每个类别的置信度。软标签是由教师模型的输出通过softmax函数计算而得到的。这意味着软标签中的每个元素都是在0到1之间的值,并且所有元素的和等于1。软标签可以看作是对每个类别的概率预测。
5. 知识蒸馏在Deit代码中的体现
1.在models.py文件中,定义了一个ViTDistilled类,用于构建学生模型。这个类继承了ViT类,但是在输出层增加了一个dist_token,用于接收教师模型的输出作为蒸馏目标。
2.在losses.py文件中,定义了一个DistillationLoss类,用于计算知识蒸馏的损失函数。这个类包括了二元交叉熵损失和注意力匹配损失两部分。
3.在main.py文件中,定义了一个DeiT类,用于封装学生模型和教师模型的训练和评估逻辑。这个类在初始化时会加载教师模型,并在训练时使用教师模型的输出作为蒸馏目标和正则化项。
4.在hubconf.py文件中,提供了一些预训练的学生模型和教师模型,可以直接加载或者下载使用。这些模型都是基于知识蒸馏方法训练得到的。