简单的知识蒸馏

embedded/2024/10/21 14:36:56/

import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
os.environ["KERAS_BACKEND"] = "tensorflow"
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'

import keras
from keras import layers
from keras import ops
import numpy as np

# 随着训练的进行,由于总损失(loss)是学生损失(student_loss)和蒸馏损失(distillation_loss)的加权和,
# 模型会同时考虑减少自身的预测损失(即提高预测准确性)和与教师模型预测分布的相似性。通过调整 alpha 参数,您
# 可以控制这两个目标之间的权衡。较大的 alpha 值将使模型更关注自身的预测准确性,而较小的 alpha 值则会使模型
# 更关注与教师模型预测分布的相似性。

# 知识蒸馏一般应该是从复杂的精度高的模型到简单的模型,让学生模型去学习教师模型的预测分布,但这个例子,因为简单的模型也
# 能达到不错的精度,所以没看出来性能提升

#教师模型一般应该是预训练模型,在高分辨率的图片数据集上训练过的,学生模型用来学习教师模型的预测概率分布
# 学生模型的结构和复杂性应该根据任务的要求、数据的特性以及资源限制来仔细选择。如果学生模型过于简单,它可能无法
# 捕捉到教师模型学习到的复杂模式和特征,导致性能不佳。另一方面,如果学生模型过于复杂,虽然它可能能够学习到更多
# 的细节和特征,但也可能导致过拟合,并且在计算资源上可能不高效。因此,在选择学生模型的结构时,需要进行权衡和实验。
# 一种常见的做法是从一个简单的模型开始,并逐步增加其复杂性,以观察性能如何变化。通过这种方式,可以找到在给定资源
# 和任务要求下性能最佳的学生模型结构。

#知识提炼器
class Distiller(keras.Model):
    def __init__(self, student, teacher):
        super().__init__()
        self.teacher = teacher#教师模型
        self.student = student#学生模型
    #编译,保存一些优化器,损失函数,权重,温度等的参数
    def compile(
        self,
        optimizer,
        metrics,
        student_loss_fn,
        distillation_loss_fn,
        alpha=0.1,
        temperature=3,
    ):
        super().compile(optimizer=optimizer, metrics=metrics)
        self.student_loss_fn = student_loss_fn#学生损失函数
        self.distillation_loss_fn = distillation_loss_fn#蒸馏损失函数
        self.alpha = alpha#蒸馏权重
        self.temperature = temperature#温度参数
    #计算损失
    def compute_loss(
        self, x=None, y=None, y_pred=None, sample_weight=None, allow_empty=False
    ):
        teacher_pred = self.teacher(x, training=False)#获取教师模型的预测
        student_loss = self.student_loss_fn(y, y_pred)#根据学生损失函数计算学生损失
        # 计算蒸馏损失。这通常涉及将教师模型和学生模型的预测都通过softmax函数(使用温度参数进行缩放),
        # 然后计算两者之间的差异。这里乘以(self.temperature**2)是一个常见的调整,用于平衡蒸馏损失。
        distillation_loss = self.distillation_loss_fn(
            ops.softmax(teacher_pred / self.temperature, axis=1),
            ops.softmax(y_pred / self.temperature, axis=1),
        ) * (self.temperature**2)
        # 根据alpha参数,将学生损失和蒸馏损失组合成一个总损失
        loss = self.alpha * student_loss + (1 - self.alpha) * distillation_loss
        return loss
    def call(self, x):
        return self.student(x)

#教师模型比较大,学生模型比较小
teacher = keras.Sequential(
    [
        keras.Input(shape=(28, 28, 1)),
        layers.Conv2D(256, (3, 3), strides=(2, 2), padding="same"),#(14,14,256)
        layers.LeakyReLU(negative_slope=0.2),
        layers.MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding="same"),#(14,14,256)
        layers.Conv2D(512, (3, 3), strides=(2, 2), padding="same"),#(7,7,512)
        layers.Flatten(),#(7*7*512)
        layers.Dense(10),
    ],
    name="teacher",
)

student = keras.Sequential(
    [
        keras.Input(shape=(28, 28, 1)),
        layers.Conv2D(16, (3, 3), strides=(2, 2), padding="same"),
        layers.LeakyReLU(negative_slope=0.2),
        layers.MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding="same"),
        layers.Conv2D(32, (3, 3), strides=(2, 2), padding="same"),
        layers.Flatten(),
        layers.Dense(10),
    ],
    name="student",
)
student_scratch = keras.models.clone_model(student)#新模型与原模型具有相同的结构,但不用原模型的权重,优化器等等

batch_size = 64
(x_train, y_train), (x_test, y_test) = keras.datasets.fashion_mnist.load_data()

print(x_train.shape,x_test.shape,x_train.dtype,np.max(x_train),np.min(x_train))

x_train = x_train.astype("float32") / 255.0
x_train = np.reshape(x_train, (-1, 28, 28, 1))

x_test = x_test.astype("float32") / 255.0
x_test = np.reshape(x_test, (-1, 28, 28, 1))

print(x_train.shape,x_test.shape,x_train.dtype,np.max(x_train),np.min(x_train))

teacher.summary()#3*3*1*256+256,3*3*256*512+512,flatten:7*7*512,把像素值展平成向量,25088*10+10

teacher.compile(
    optimizer=keras.optimizers.Adam(),#优化器
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),#损失函数:多元交叉熵
    metrics=[keras.metrics.SparseCategoricalAccuracy()],#指标:准确率
)

teacher.fit(x_train, y_train, epochs=5)

teacher.evaluate(x_test, y_test)

distiller = Distiller(student=student, teacher=teacher)#构建知识提炼器

distiller.compile(
    optimizer=keras.optimizers.Adam(),#优化器
    metrics=[keras.metrics.SparseCategoricalAccuracy('acc')],#度量指标
    student_loss_fn=keras.losses.SparseCategoricalCrossentropy(from_logits=True),#学生损失函数:多元交叉熵
    distillation_loss_fn=keras.losses.KLDivergence(),#知识提炼损失:kld
    alpha=0.1,#提炼权重(用来设定学生和提炼损失的占比)
    temperature=10,#温度,缩放系数
)

distiller.fit(x_train, y_train, epochs=3)#提炼教师到学生(让学生模型能学习教师模型的预测分布)

distiller.evaluate(x_test, y_test)

student_scratch.compile(#这个拷贝模型的职责就是衡量知识提炼中学生模型究竟从教师模型中学到了多少东西
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.SparseCategoricalAccuracy('acc')],
)

student_scratch.fit(x_train, y_train, epochs=3)

student_scratch.evaluate(x_test, y_test)


http://www.ppmy.cn/embedded/33644.html

相关文章

Redis-分片机制

概述 业务需要:由于单台redis内存容量是有限的,无法实现海量的数据实现缓存存储 概念:由多个redis节点协助工作的机制就是redis的分片机制 作用:为了实现redis扩容 特点:分片机制把该机制中包含的多台redis缓存服务…

【前端学习——css】css实现给背景的图片加模糊

我在制作自己的博客的时候,打算做个封面,封面用半透明颜色盖住了预览图,上面印上了文字,但背景图太乱了,所以打算给背景加模糊效果。 效果 方法 主要就是利用这个属性 backdrop-filter: blur(5px);属性很简单&#x…

基于纯JavaScript实现的MODBUS-RTU(串口和TCP) modbus-serial

modbus-serial 如果你需要使用JavaScript来操作一台RS458的设备,那么你一定不能错过这个库 modbus-serial。 安装和使用 npm install modbus-serial支持的功能码 功能码函数FC1 读取读线圈寄存器readCoils(coil, len) FC2 读离散输入寄存器readDiscreteInputs(a…

C语言实战项目--贪吃蛇

贪吃蛇是久负盛名的游戏之一,它也和俄罗斯⽅块,扫雷等游戏位列经典游戏的行列。在编程语言的教学中,我们以贪吃蛇为例,从设计到代码实现来提升大家的编程能⼒和逻辑能⼒。 在本篇讲解中,我们会看到很多陌生的知识&…

linux下载安装JDK

查看系统是否自带 jdk java -version 一、jdk下载安装 jdk11下载 上传到 linux 以下说明已下载 解压 tar -xzvf jdk-11.0.23_linux-x64_bin.tar.gz 查看是否安装成功 二、linux配置JDK环境 sudo vim /etc/profile JAVA_HOME/may2024/jdk-11.0.23 JRE_HOME$JAVA_HOME/…

学习100个Unity Shader (16) --- 程序纹理简述

文章目录 理解参考 理解 程序纹理顾名思义,就是通过代码生成的纹理,然后传入材质,生成图像。 假设,给一个模型添加了材质,并赋予了一个shader。shader中有一个纹理属性叫_MainTex。 程序纹理简单来说就是,…

C++静态数组和C语言静态数组的区别( array,int a[])

目录 一、区别 1、越界读,检查不出来 2、越界写,抽查,可能检查不出来,有局限性 二、array缺点 一、区别 C语言的静态数组int a[]; 静态数组的越界检查不稳定的: 1、越界读,检查不出来 2、越界写&#x…

二叉搜索树

一、概念 二叉搜索树又称二叉排序树,它或者是一棵空树,或者是具有以下性质的二叉树: 若它的左子树不为空,则左子树上所有节点的值都小于根节点的值 若它的右子树不为空,则右子树上所有节点的值都大于根节点的值 它的左右子树也分…