政安晨:【Keras机器学习示例演绎】(三十三)—— 知识提炼

ops/2024/12/22 23:27:55/

目录

设置

构建 Distiller() 类

创建学生和教师模型

准备数据集

培训教师

将教师模型蒸馏给学生模型

从头开始训练学生进行比较


政安晨的个人主页政安晨

欢迎 👍点赞✍评论⭐收藏

收录专栏: TensorFlow与Keras机器学习实战

希望政安晨的博客能够对您有所裨益,如有不足之处,欢迎在评论区提出指正!

本文目标:实施经典知识蒸馏。

知识蒸馏简介

知识蒸馏(Knowledge Distillation)是一种模型压缩程序,在这种程序中,一个小的(学生)模型被训练成与一个大的预训练(教师)模型相匹配。通过最小化损失函数,将知识从教师模型转移到学生模型,目的是匹配软化的教师对数以及地面实况标签。

通过在 softmax 中应用 "温度 "缩放函数来软化对数,从而有效地平滑概率分布,并揭示教师所学的类间关系。

设置

import osimport keras
from keras import layers
from keras import ops
import numpy as np

构建 Distiller() 类

自定义 Distiller() 类覆盖了编译、计算损失和调用模型方法。

为了使用蒸馏器,我们需要:

一个训练有素的教师模型
一个要训练的学生模型
一个关于学生预测与地面实况之间差异的学生损失函数
学生软预测与教师软标签之间差值的蒸馏损失函数以及温度
用于加权学生损失和蒸馏损失的 alpha 因子
学生的优化器和(可选)性能评估指标

在 compute_loss 方法中,我们对教师和学生进行前向传递,计算损失,并分别按 alpha 和 1 - alpha 对 student_loss 和 distillation_loss 进行加权。注意:只更新学生权重。

class Distiller(keras.Model):def __init__(self, student, teacher):super().__init__()self.teacher = teacherself.student = studentdef compile(self,optimizer,metrics,student_loss_fn,distillation_loss_fn,alpha=0.1,temperature=3,):"""Configure the distiller.Args:optimizer: Keras optimizer for the student weightsmetrics: Keras metrics for evaluationstudent_loss_fn: Loss function of difference between studentpredictions and ground-truthdistillation_loss_fn: Loss function of difference between softstudent predictions and soft teacher predictionsalpha: weight to student_loss_fn and 1-alpha to distillation_loss_fntemperature: Temperature for softening probability distributions.Larger temperature gives softer distributions."""super().compile(optimizer=optimizer, metrics=metrics)self.student_loss_fn = student_loss_fnself.distillation_loss_fn = distillation_loss_fnself.alpha = alphaself.temperature = temperaturedef compute_loss(self, x=None, y=None, y_pred=None, sample_weight=None, allow_empty=False):teacher_pred = self.teacher(x, training=False)student_loss = self.student_loss_fn(y, y_pred)distillation_loss = self.distillation_loss_fn(ops.softmax(teacher_pred / self.temperature, axis=1),ops.softmax(y_pred / self.temperature, axis=1),) * (self.temperature**2)loss = self.alpha * student_loss + (1 - self.alpha) * distillation_lossreturn lossdef call(self, x):return self.student(x)

创建学生和教师模型


首先,我们创建一个教师模型和一个较小的学生模型。这两个模型都是使用 Sequential() 创建的卷积神经网络,但也可以是任何 Keras 模型。

# Create the teacher
teacher = keras.Sequential([keras.Input(shape=(28, 28, 1)),layers.Conv2D(256, (3, 3), strides=(2, 2), padding="same"),layers.LeakyReLU(negative_slope=0.2),layers.MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding="same"),layers.Conv2D(512, (3, 3), strides=(2, 2), padding="same"),layers.Flatten(),layers.Dense(10),],name="teacher",
)# Create the student
student = keras.Sequential([keras.Input(shape=(28, 28, 1)),layers.Conv2D(16, (3, 3), strides=(2, 2), padding="same"),layers.LeakyReLU(negative_slope=0.2),layers.MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding="same"),layers.Conv2D(32, (3, 3), strides=(2, 2), padding="same"),layers.Flatten(),layers.Dense(10),],name="student",
)# Clone student for later comparison
student_scratch = keras.models.clone_model(student)

准备数据集


用于训练教师和提炼教师的数据集是 MNIST,如果选择合适的模型,该过程也可用于任何其他数据集,如 CIFAR-10。学生和教师都在训练集上进行训练,并在测试集上进行评估。

# Prepare the train and test dataset.
batch_size = 64
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()# Normalize data
x_train = x_train.astype("float32") / 255.0
x_train = np.reshape(x_train, (-1, 28, 28, 1))x_test = x_test.astype("float32") / 255.0
x_test = np.reshape(x_test, (-1, 28, 28, 1))

培训教师


知识提炼过程中,我们假定教师是经过训练且固定不变的。因此,我们首先按照常规方法在训练集上训练教师模型。

# Train teacher as usual
teacher.compile(optimizer=keras.optimizers.Adam(),loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),metrics=[keras.metrics.SparseCategoricalAccuracy()],
)# Train and evaluate teacher on data.
teacher.fit(x_train, y_train, epochs=5)
teacher.evaluate(x_test, y_test)
Epoch 1/51875/1875 ━━━━━━━━━━━━━━━━━━━━ 8s 3ms/step - loss: 0.2408 - sparse_categorical_accuracy: 0.9259
Epoch 2/51875/1875 ━━━━━━━━━━━━━━━━━━━━ 5s 3ms/step - loss: 0.0912 - sparse_categorical_accuracy: 0.9726
Epoch 3/51875/1875 ━━━━━━━━━━━━━━━━━━━━ 7s 4ms/step - loss: 0.0758 - sparse_categorical_accuracy: 0.9777
Epoch 4/51875/1875 ━━━━━━━━━━━━━━━━━━━━ 5s 3ms/step - loss: 0.0690 - sparse_categorical_accuracy: 0.9797
Epoch 5/51875/1875 ━━━━━━━━━━━━━━━━━━━━ 5s 3ms/step - loss: 0.0582 - sparse_categorical_accuracy: 0.9825313/313 ━━━━━━━━━━━━━━━━━━━━ 1s 3ms/step - loss: 0.0931 - sparse_categorical_accuracy: 0.9760[0.09044107794761658, 0.978100061416626]

将教师模型蒸馏给学生模型


我们已经训练了教师模型,现在只需初始化 Distiller(student, teacher) 实例,用所需的损失、超参数和优化器编译()它,然后将教师模型蒸馏为学生模型。

# Initialize and compile distiller
distiller = Distiller(student=student, teacher=teacher)
distiller.compile(optimizer=keras.optimizers.Adam(),metrics=[keras.metrics.SparseCategoricalAccuracy()],student_loss_fn=keras.losses.SparseCategoricalCrossentropy(from_logits=True),distillation_loss_fn=keras.losses.KLDivergence(),alpha=0.1,temperature=10,
)# Distill teacher to student
distiller.fit(x_train, y_train, epochs=3)# Evaluate student on test dataset
distiller.evaluate(x_test, y_test)
Epoch 1/31875/1875 ━━━━━━━━━━━━━━━━━━━━ 8s 3ms/step - loss: 1.8752 - sparse_categorical_accuracy: 0.7357
Epoch 2/31875/1875 ━━━━━━━━━━━━━━━━━━━━ 6s 3ms/step - loss: 0.0333 - sparse_categorical_accuracy: 0.9475
Epoch 3/31875/1875 ━━━━━━━━━━━━━━━━━━━━ 6s 3ms/step - loss: 0.0223 - sparse_categorical_accuracy: 0.9621313/313 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - loss: 0.0189 - sparse_categorical_accuracy: 0.9629[0.017046602442860603, 0.969200074672699]

从头开始训练学生进行比较


我们还可以在没有教师的情况下,从头开始训练一个等效的学生模型,以评估通过知识提炼获得的性能增益。

# Train student as doen usually
student_scratch.compile(optimizer=keras.optimizers.Adam(),loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),metrics=[keras.metrics.SparseCategoricalAccuracy()],
)# Train and evaluate student trained from scratch.
student_scratch.fit(x_train, y_train, epochs=3)
student_scratch.evaluate(x_test, y_test)
Epoch 1/31875/1875 ━━━━━━━━━━━━━━━━━━━━ 4s 1ms/step - loss: 0.5111 - sparse_categorical_accuracy: 0.8460
Epoch 2/31875/1875 ━━━━━━━━━━━━━━━━━━━━ 3s 1ms/step - loss: 0.1039 - sparse_categorical_accuracy: 0.9687
Epoch 3/31875/1875 ━━━━━━━━━━━━━━━━━━━━ 3s 1ms/step - loss: 0.0748 - sparse_categorical_accuracy: 0.9780313/313 ━━━━━━━━━━━━━━━━━━━━ 1s 2ms/step - loss: 0.0744 - sparse_categorical_accuracy: 0.9737[0.0629437193274498, 0.9778000712394714]

如果教师训练了 5 个完整的历元,学生在教师的基础上提炼了 3 个完整的历元,那么在这个示例中,与从头开始训练相同的学生模型相比,甚至与教师本身相比,您都会体验到性能的提升。

教师的准确率应该在 97.6% 左右,从头开始训练的学生的准确率应该在 97.6% 左右,而经过提炼的学生的准确率应该在 98.1% 左右。删除或尝试不同的种子,使用不同的权重初始化。



http://www.ppmy.cn/ops/31591.html

相关文章

如何使用 GPT API 从 PDF 出版物导出研究图表?

原文地址:how-to-use-gpt-api-to-export-a-research-graph-from-pdf-publications 揭示内部结构——提取研究实体和关系 2024 年 2 月 6 日 介绍 研究图是研究对象的结构化表示,它捕获有关实体的信息以及研究人员、组织、出版物、资助和研究数据之间的关…

多国语言免费在线客服系统源码,网站在线客服系统,网页在线客服软件在线聊天通讯平台

详情介绍 多国语言免费在线客服系统源码,网站在线客服系统,网页在线客服软件在线聊天通讯平台 新款在线客服系统全开源无加密:多商户、国际化多语言、智能机器人、自动回复、语音聊天、 文件发送、系统强力防黑加固、不限坐席、国际外贸、超多功能 支持手机移动端和PC网页…

nginx--配置文件

组成 主配置文件:nginx.conf 子配置文件:include conf.d/*.conf 协议相关的配置文件:fastcgi uwsgi scgi等 mime.types:⽀持的mime类型,MIME(Multipurpose Internet Mail Extensions)多用途互联⽹网邮件扩展类型&…

自动驾驶-第02课软件环境基础(ROSCMake)

1. 什么是ros 2. 为什么使用ros 3. ROS通信 3.1 Catkin编译系统

什么是TCP粘包?

TCP粘包 数据的接收和发送是无关的,read()/recv() 函数不管数据发送了多少次,都会尽可能多的接收数据。也就是说,read()/recv() 和 write()/send() 的执行次数可能不同。 举个栗子 write()/send() 重复执行三次,每次都发送字符…

什么是binutils-arm-linux-gnueabi

2024年5月3日,周五晚上 binutils-arm-linux-gnueabi 是针对 ARM 架构的 Linux 系统开发的 GNU Binutils 工具链。Binutils 是一组用于汇编、链接和转换目标文件的工具,包括 as (汇编器)、ld (链接器)、objcopy (目标文件转换工具) 等。 binutils-arm-li…

LLM大语言模型原理、发展历程、训练方法、应用场景和未来趋势

LLM,全称Large Language Model,即大型语言模型。LLM是一种强大的人工智能算法,它通过训练大量文本数据,学习语言的语法、语义和上下文信息,从而能够对自然语言文本进行建模。这种模型在自然语言处理(NLP&am…

JAVA设计模式

**************************************************************************************************************************************************************************** 1、设计模式概述 【1】前辈们对代码开发经验的总结,解决问题的套路。是用来提…