Knowledge Distillation Tutorial

ops/2025/2/6 1:10:18/

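Source: Knowledge Distillation Tutorial
Distilling a large model into a small one saves compute, speeds up inference, and lets the model run more efficiently.

This tutorial uses the CIFAR-10 dataset.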

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets

# Use the GPU when one is available; the CPU also works
device = "cuda" if torch.cuda.is_available() else "cpu"

# Convert images to tensors and normalize each channel
transforms_cifar = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Loading the CIFAR-10 dataset:
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transforms_cifar)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transforms_cifar)

# Dataloaders
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=2)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=2)
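
As a rough illustration of what the normalization step does to each pixel, the sketch below applies (x - mean) / std per channel in plain Python. The helper name normalize_pixel is hypothetical. The mean and std values are the ones used in the listing above, which are the commonly used ImageNet statistics rather than statistics computed on CIFAR-10.

```python
# Per-channel normalization: (x - mean) / std, as done by transforms.Normalize.
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]


def normalize_pixel(rgb):
    """Normalize one RGB pixel whose channels are already scaled to [0, 1]."""
    return [(x - m) / s for x, m, s in zip(rgb, MEAN, STD)]


black = normalize_pixel([0.0, 0.0, 0.0])
white = normalize_pixel([1.0, 1.0, 1.0])
print([round(v, 3) for v in black])
print([round(v, 3) for v in white])
```

After normalization the inputs are roughly centered around zero instead of lying in [0, 1].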

Defining the models

Define two models with a similar structure that differ only in width and depth.
Teacher model: DeepNN

# Deeper neural network class to be used as teacher:
class DeepNN(nn.Module):
    def __init__(self, num_classes=10):
        super(DeepNN, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Linear(2048, 512),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

Student model: LightNN

# Lightweight neural network class to be used as student:
class LightNN(nn.Module):
    def __init__(self, num_classes=10):
        super(LightNN, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(16, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Linear(1024, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x
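
The input sizes of the two classifiers, 2048 and 1024, follow from the feature extractors. A small sketch of that arithmetic, assuming 32x32 CIFAR-10 inputs; flattened_size is a hypothetical helper and not part of the tutorial:

```python
def flattened_size(in_hw, out_channels, num_pools):
    """Length of the flattened feature map.

    3x3 convolutions with padding=1 keep height and width unchanged;
    each 2x2 max-pool with stride 2 halves them.
    """
    hw = in_hw // (2 ** num_pools)
    return out_channels * hw * hw


teacher_flat = flattened_size(in_hw=32, out_channels=32, num_pools=2)
student_flat = flattened_size(in_hw=32, out_channels=16, num_pools=2)
print(teacher_flat, student_flat)  # 2048 1024
```

Both networks end with 8x8 feature maps; only the channel count (32 versus 16) differs.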


Training and testing the models

def train(model, train_loader, epochs, learning_rate, device):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    model.train()

    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in train_loader:
            # inputs: A collection of batch_size images
            # labels: A vector of dimensionality batch_size with integers denoting class of each image
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)

            # outputs: Output of the network for the collection of images. A tensor of dimensionality batch_size x num_classes
            # labels: The actual labels of the images. Vector of dimensionality batch_size
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}")

def test(model, test_loader, device):
    model.to(device)
    model.eval()

    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f"Test Accuracy: {accuracy:.2f}%")
    return accuracy

# Instantiate, train, and test the deep network:
torch.manual_seed(42)
nn_deep = DeepNN(num_classes=10).to(device)
train(nn_deep, train_loader, epochs=10, learning_rate=0.001, device=device)
test_accuracy_deep = test(nn_deep, test_loader, device)

# Instantiate the lightweight network:
torch.manual_seed(42)
nn_light = LightNN(num_classes=10).to(device)
train(nn_light, train_loader, epochs=10, learning_rate=0.001, device=device)
test_accuracy_light_ce = test(nn_light, test_loader, device)

DeepNN has 1,186,986 parameters and reaches a test accuracy of 75.98%.
LightNN has 267,738 parameters and reaches a test accuracy of 70.65%.

total_params_deep = "{:,}".format(sum(p.numel() for p in nn_deep.parameters()))
print(f"DeepNN parameters: {total_params_deep}")
total_params_light = "{:,}".format(sum(p.numel() for p in nn_light.parameters()))
print(f"LightNN parameters: {total_params_light}")
print(f"Teacher accuracy: {test_accuracy_deep:.2f}%")
print(f"Student accuracy: {test_accuracy_light_ce:.2f}%")
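
The two parameter counts can be checked by hand from the layer definitions. The sketch below is a plain-Python tally under the usual counting rule (weights plus one bias per output unit); conv_params and linear_params are hypothetical helpers, not PyTorch functions.

```python
def conv_params(c_in, c_out, k=3):
    """Weights (c_in * c_out * k * k) plus one bias per output channel."""
    return c_in * c_out * k * k + c_out


def linear_params(n_in, n_out):
    """Weights (n_in * n_out) plus one bias per output unit."""
    return n_in * n_out + n_out


deep_total = (
    conv_params(3, 128) + conv_params(128, 64)
    + conv_params(64, 64) + conv_params(64, 32)
    + linear_params(2048, 512) + linear_params(512, 10)
)
light_total = (
    conv_params(3, 16) + conv_params(16, 16)
    + linear_params(1024, 256) + linear_params(256, 10)
)
print(f"{deep_total:,} {light_total:,}")  # 1,186,986 267,738
```

In both networks most of the parameters sit in the first fully connected layer of the classifier.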

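Knowledge distillation

Both the teacher and the student output a probability distribution over the classes. The working assumption is that the softmax output of a trained teacher carries more information than the hard label, and that this information can help improve the student's accuracy. For example, with the classes (car, train, motorcycle), the label of a car image is [1, 0, 0], while a trained teacher might output [0.6, 0.2, 0.2]. With the classes (car, dog, cat), the teacher might output [0.8, 0.1, 0.1] for the same image, because a car resembles a train or a motorcycle more than it resembles a dog or a cat. Having the student learn this part of the teacher's knowledge is called knowledge distillation.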
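
One way to see that soft outputs say more than a hard label is to compare their entropy. This is only a sketch: the probabilities are the illustrative values from the example above, not measured teacher outputs, and entropy is just one possible proxy for "extra information".

```python
import math


def entropy(probs):
    """Shannon entropy in nats; entries with p == 0 contribute nothing."""
    return sum(p * math.log(1.0 / p) for p in probs if p > 0)


hard_label = [1.0, 0.0, 0.0]
teacher_similar = [0.6, 0.2, 0.2]     # car vs. train, motorcycle
teacher_dissimilar = [0.8, 0.1, 0.1]  # car vs. dog, cat

for name, probs in [
    ("hard label", hard_label),
    ("similar classes", teacher_similar),
    ("dissimilar classes", teacher_dissimilar),
]:
    print(f"{name}: {entropy(probs):.3f}")
```

The hard label has zero entropy in both cases, so it cannot distinguish them, while the two teacher outputs differ.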

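The loss between the student and the ground truth is the cross-entropy loss.
The loss between the student and the teacher is the KL-divergence loss.

During distillation the teacher is frozen and only the student is trained.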

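Additional parameters:

  • T: the temperature, which controls how smooth the output distribution is. A larger T gives a smoother distribution, so small probabilities receive a larger boost.
  • soft_target_loss_weight: the weight of the loss between the student and the teacher.
  • ce_loss_weight: the weight of the loss between the student and the ground truth.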
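
To make the roles of T and the two weights concrete, here is a framework-free sketch of the combined loss for a single example. The logits are made up for illustration, and kd_loss is a hypothetical helper that mirrors the description above (KL divergence on temperature-softened distributions, scaled by T squared, plus cross-entropy on the true label).

```python
import math


def softmax(logits, T=1.0):
    """Softmax of logits / T, computed in a numerically stable way."""
    scaled = [z / T for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kd_loss(student_logits, teacher_logits, label, T,
            soft_target_loss_weight, ce_loss_weight):
    soft_targets = softmax(teacher_logits, T)
    soft_prob = softmax(student_logits, T)
    # KL(teacher || student) on the softened distributions, scaled by T**2
    kl = sum(t * (math.log(t) - math.log(s))
             for t, s in zip(soft_targets, soft_prob))
    soft_targets_loss = kl * T ** 2
    # Cross-entropy with the true label uses the unsoftened student output
    label_loss = -math.log(softmax(student_logits)[label])
    return soft_target_loss_weight * soft_targets_loss + ce_loss_weight * label_loss


teacher_logits = [4.0, 2.0, 1.0]  # made-up values
student_logits = [2.5, 1.0, 0.5]  # made-up values

p_t1 = softmax(teacher_logits, T=1)
p_t2 = softmax(teacher_logits, T=2)
print([round(p, 3) for p in p_t1])
print([round(p, 3) for p in p_t2])
loss = kd_loss(student_logits, teacher_logits, label=0, T=2,
               soft_target_loss_weight=0.25, ce_loss_weight=0.75)
print(round(loss, 4))
```

With T=2 the largest probability shrinks and the smaller ones grow, which is the smoothing effect described for the temperature.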
def train_knowledge_distillation(teacher, student, train_loader, epochs, learning_rate, T, soft_target_loss_weight, ce_loss_weight, device):
    ce_loss = nn.CrossEntropyLoss()
    optimizer = optim.Adam(student.parameters(), lr=learning_rate)

    teacher.eval()  # Teacher set to evaluation mode
    student.train() # Student to train mode

    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()

            # Forward pass with the teacher model - do not save gradients here as we do not change the teacher's weights
            with torch.no_grad():
                teacher_logits = teacher(inputs)

            # Forward pass with the student model
            student_logits = student(inputs)

            # Soften the teacher logits with softmax and the student logits with log_softmax
            soft_targets = nn.functional.softmax(teacher_logits / T, dim=-1)
            soft_prob = nn.functional.log_softmax(student_logits / T, dim=-1)

            # Calculate the soft targets loss. Scaled by T**2 as suggested by the authors of the paper "Distilling the knowledge in a neural network"
            soft_targets_loss = torch.sum(soft_targets * (soft_targets.log() - soft_prob)) / soft_prob.size()[0] * (T**2)

            # Calculate the true label loss
            label_loss = ce_loss(student_logits, labels)

            # Weighted sum of the two losses
            loss = soft_target_loss_weight * soft_targets_loss + ce_loss_weight * label_loss

            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}")

# new_nn_light is not defined in the listing as posted; a fresh student
# initialized with the same seed as nn_light is assumed here
torch.manual_seed(42)
new_nn_light = LightNN(num_classes=10).to(device)

# Apply ``train_knowledge_distillation`` with a temperature of 2. Arbitrarily set the weights to 0.75 for CE and 0.25 for distillation loss.
train_knowledge_distillation(teacher=nn_deep, student=new_nn_light, train_loader=train_loader, epochs=10, learning_rate=0.001, T=2, soft_target_loss_weight=0.25, ce_loss_weight=0.75, device=device)
test_accuracy_light_ce_and_kd = test(new_nn_light, test_loader, device)

# Compare the student test accuracy with and without the teacher, after distillation
print(f"Teacher accuracy: {test_accuracy_deep:.2f}%")
print(f"Student accuracy without teacher: {test_accuracy_light_ce:.2f}%")
print(f"Student accuracy with CE + KD: {test_accuracy_light_ce_and_kd:.2f}%")

#Test Accuracy: 70.49%
#Teacher accuracy: 75.98%
#Student accuracy without teacher: 70.65%
#Student accuracy with CE + KD: 70.49%

CosineEmbeddingLoss

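The goal of distillation is for the student to learn the teacher's knowledge, so besides the final output distribution it can also learn the teacher's internal representations (hidden states).
The intermediate output vectors of the two models can be compared with CosineEmbeddingLoss.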
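
For pairs that should be similar (target 1), the cosine embedding loss reduces to 1 minus the cosine similarity of the two vectors. A plain-Python sketch of that case only; the function names and vectors are hypothetical:

```python
import math


def cosine_similarity(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def cosine_loss_target_one(a, b):
    """Loss for a pair that should point in the same direction."""
    return 1.0 - cosine_similarity(a, b)


aligned = cosine_loss_target_one([1.0, 2.0, 3.0], [2.0, 4.0, 6.0])
orthogonal = cosine_loss_target_one([1.0, 0.0], [0.0, 1.0])
opposite = cosine_loss_target_one([1.0, 2.0], [-1.0, -2.0])
print(round(aligned, 6), round(orthogonal, 6), round(opposite, 6))
```

Minimizing this loss pushes the student's vector toward the same direction as the teacher's, regardless of its length.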
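In the models above, the teacher's flattened output has 2048 dimensions while the student's has 1024, so an extra pooling layer is added to the teacher to bring both to the same dimensionality.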
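
A minimal sketch of what average pooling with a window of 2 does to a flattened vector: adjacent pairs are averaged, which halves the length from 2048 to 1024. The helper avg_pool_pairs and the input values are hypothetical.

```python
def avg_pool_pairs(vector, kernel_size=2):
    """Average non-overlapping windows (stride equals kernel_size)."""
    n = len(vector) // kernel_size
    return [
        sum(vector[i * kernel_size:(i + 1) * kernel_size]) / kernel_size
        for i in range(n)
    ]


teacher_flat_vector = [float(i) for i in range(2048)]  # stand-in values
pooled = avg_pool_pairs(teacher_flat_vector)
print(len(pooled), pooled[:3])  # 1024 [0.5, 2.5, 4.5]
```

The pooling has no trainable parameters, so the teacher's trained weights can be reused unchanged.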

class ModifiedDeepNNCosine(nn.Module):
    def __init__(self, num_classes=10):
        super(ModifiedDeepNNCosine, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Linear(2048, 512),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        flattened_conv_output = torch.flatten(x, 1)
        x = self.classifier(flattened_conv_output)
        flattened_conv_output_after_pooling = torch.nn.functional.avg_pool1d(flattened_conv_output, 2)
        return x, flattened_conv_output_after_pooling

# Create a similar student class where we return a tuple. We do not apply pooling after flattening.
class ModifiedLightNNCosine(nn.Module):
    def __init__(self, num_classes=10):
        super(ModifiedLightNNCosine, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(16, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Linear(1024, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        flattened_conv_output = torch.flatten(x, 1)
        x = self.classifier(flattened_conv_output)
        return x, flattened_conv_output

# We do not have to train the modified deep network from scratch of course, we just load its weights from the trained instance
modified_nn_deep = ModifiedDeepNNCosine(num_classes=10).to(device)
modified_nn_deep.load_state_dict(nn_deep.state_dict())

# Once again ensure the norm of the first layer is the same for both networks
print("Norm of 1st layer for deep_nn:", torch.norm(nn_deep.features[0].weight).item())
print("Norm of 1st layer for modified_deep_nn:", torch.norm(modified_nn_deep.features[0].weight).item())

# Initialize a modified lightweight network with the same seed as our other lightweight instances. This will be trained from scratch to examine the effectiveness of cosine loss minimization.
torch.manual_seed(42)
modified_nn_light = ModifiedLightNNCosine(num_classes=10).to(device)
print("Norm of 1st layer:", torch.norm(modified_nn_light.features[0].weight).item())

The training and test functions change accordingly.

def train_cosine_loss(teacher, student, train_loader, epochs, learning_rate, hidden_rep_loss_weight, ce_loss_weight, device):
    ce_loss = nn.CrossEntropyLoss()
    cosine_loss = nn.CosineEmbeddingLoss()
    optimizer = optim.Adam(student.parameters(), lr=learning_rate)

    teacher.to(device)
    student.to(device)
    teacher.eval()  # Teacher set to evaluation mode
    student.train() # Student to train mode

    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()

            # Forward pass with the teacher model and keep only the hidden representation
            with torch.no_grad():
                _, teacher_hidden_representation = teacher(inputs)

            # Forward pass with the student model
            student_logits, student_hidden_representation = student(inputs)

            # Calculate the cosine loss. Target is a vector of ones, the case where loss minimization leads to cosine similarity increase.
            hidden_rep_loss = cosine_loss(student_hidden_representation, teacher_hidden_representation, target=torch.ones(inputs.size(0)).to(device))

            # Calculate the true label loss
            label_loss = ce_loss(student_logits, labels)

            # Weighted sum of the two losses
            loss = hidden_rep_loss_weight * hidden_rep_loss + ce_loss_weight * label_loss

            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}")

def test_multiple_outputs(model, test_loader, device):
    model.to(device)
    model.eval()

    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            outputs, _ = model(inputs) # Disregard the second tensor of the tuple
            _, predicted = torch.max(outputs.data, 1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f"Test Accuracy: {accuracy:.2f}%")
    return accuracy

# Train and test the lightweight network with cross entropy loss plus cosine loss
train_cosine_loss(teacher=modified_nn_deep, student=modified_nn_light, train_loader=train_loader, epochs=10, learning_rate=0.001, hidden_rep_loss_weight=0.25, ce_loss_weight=0.75, device=device)
test_accuracy_light_ce_and_cosine_loss = test_multiple_outputs(modified_nn_light, test_loader, device)
#Test Accuracy: 70.12%

Intermediate regressor run

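For high-dimensional vectors, cosine similarity usually works better than Euclidean distance, but here each vector has 1024 components, so extracting a meaningful similarity is harder. In addition, as mentioned, there is no theoretical support for pushing the teacher's and the student's hidden representations to match. There is no good reason to aim for a 1:1 match between these vectors.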
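
One property that separates the two measures is scale: cosine similarity ignores vector length, while Euclidean distance does not. A small sketch with hypothetical vectors:

```python
import math

a = [1.0, 2.0, 3.0]
b = [10.0, 20.0, 30.0]  # same direction as a, ten times longer

dot = sum(x * y for x, y in zip(a, b))
cos_ab = dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))
euclid_ab = math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))

print(round(cos_ab, 6), round(euclid_ab, 3))
```

The two vectors count as identical under cosine similarity but far apart under Euclidean distance.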
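The authors argue that in the previous distillation setup the student and the teacher are compared on a vector, the output of torch.flatten(x, 1), whose expressive power is limited. They therefore take the layer just before the flatten and compare the convolutional feature maps instead.
The teacher's feature map has shape [128, 32, 8, 8] and the student's has shape [128, 16, 8, 8] (batch, channels, height, width), so a convolutional layer is needed to align the dimensions.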
A regressor layer is added to the student model.
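
A quick check of the shapes involved, assuming the regressor is the 3x3 convolution with padding 1 that maps 16 channels to 32, as in this tutorial's student model. The helper conv2d_output_hw is hypothetical and uses the standard convolution output-size formula.

```python
def conv2d_output_hw(in_hw, kernel_size, padding, stride=1):
    """Spatial output size of a square convolution."""
    return (in_hw + 2 * padding - kernel_size) // stride + 1


batch, student_channels, hw = 128, 16, 8
regressor_channels = 32

out_hw = conv2d_output_hw(hw, kernel_size=3, padding=1)
regressor_shape = [batch, regressor_channels, out_hw, out_hw]
teacher_shape = [128, 32, 8, 8]

# Weights plus one bias per output channel
extra_params = student_channels * regressor_channels * 3 * 3 + regressor_channels
print(regressor_shape, extra_params)  # [128, 32, 8, 8] 4640
```

The regressor output matches the teacher's feature map shape, so an element-wise loss such as MSE can be applied directly.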

class ModifiedDeepNNRegressor(nn.Module):
    def __init__(self, num_classes=10):
        super(ModifiedDeepNNRegressor, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Linear(2048, 512),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        conv_feature_map = x
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x, conv_feature_map

class ModifiedLightNNRegressor(nn.Module):
    def __init__(self, num_classes=10):
        super(ModifiedLightNNRegressor, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(16, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        # Include an extra regressor (in our case linear)
        self.regressor = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=3, padding=1)
        )
        self.classifier = nn.Sequential(
            nn.Linear(1024, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        regressor_output = self.regressor(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x, regressor_output

def train_mse_loss(teacher, student, train_loader, epochs, learning_rate, feature_map_weight, ce_loss_weight, device):
    ce_loss = nn.CrossEntropyLoss()
    mse_loss = nn.MSELoss()
    optimizer = optim.Adam(student.parameters(), lr=learning_rate)

    teacher.to(device)
    student.to(device)
    teacher.eval()  # Teacher set to evaluation mode
    student.train() # Student to train mode

    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()

            # Again ignore teacher logits
            with torch.no_grad():
                _, teacher_feature_map = teacher(inputs)

            # Forward pass with the student model
            student_logits, regressor_feature_map = student(inputs)

            # Calculate the loss
            hidden_rep_loss = mse_loss(regressor_feature_map, teacher_feature_map)

            # Calculate the true label loss
            label_loss = ce_loss(student_logits, labels)

            # Weighted sum of the two losses
            loss = feature_map_weight * hidden_rep_loss + ce_loss_weight * label_loss

            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}")

# Notice how our test function remains the same here with the one we used in our previous case. We only care about the actual outputs because we measure accuracy.

# Initialize a ModifiedLightNNRegressor
torch.manual_seed(42)
modified_nn_light_reg = ModifiedLightNNRegressor(num_classes=10).to(device)

# We do not have to train the modified deep network from scratch of course, we just load its weights from the trained instance
modified_nn_deep_reg = ModifiedDeepNNRegressor(num_classes=10).to(device)
modified_nn_deep_reg.load_state_dict(nn_deep.state_dict())

# Train and test once again
train_mse_loss(teacher=modified_nn_deep_reg, student=modified_nn_light_reg, train_loader=train_loader, epochs=10, learning_rate=0.001, feature_map_weight=0.25, ce_loss_weight=0.75, device=device)
test_accuracy_light_ce_and_mse_loss = test_multiple_outputs(modified_nn_light_reg, test_loader, device)

print(f"Teacher accuracy: {test_accuracy_deep:.2f}%")
print(f"Student accuracy without teacher: {test_accuracy_light_ce:.2f}%")
print(f"Student accuracy with CE + KD: {test_accuracy_light_ce_and_kd:.2f}%")
print(f"Student accuracy with CE + CosineLoss: {test_accuracy_light_ce_and_cosine_loss:.2f}%")
print(f"Student accuracy with CE + RegressorMSE: {test_accuracy_light_ce_and_mse_loss:.2f}%")

#Teacher accuracy: 75.98%
#Student accuracy without teacher: 70.65%
#Student accuracy with CE + KD: 70.49%
#Student accuracy with CE + CosineLoss: 70.12%
#Student accuracy with CE + RegressorMSE: 70.61%

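The RegressorMSE method works better than CosineLoss because a trainable layer sits between the teacher and the student. This gives the student some leeway in learning instead of forcing it to copy the teacher's representation. Including the extra network is the idea behind hint-based distillation.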
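
A toy sketch of that leeway, under strong simplifying assumptions: the student's features are held fixed, the regressor is a small linear map instead of a convolution, and all numbers are made up. Gradient descent on the regressor weights alone drives the MSE against the teacher's features toward zero, so the student's own representation does not have to equal the teacher's.

```python
def mse(pred, target):
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(target)


def apply_regressor(W, s):
    """Linear map: one output per row of W."""
    return [sum(w * x for w, x in zip(row, s)) for row in W]


student_features = [1.0, 2.0]        # hypothetical, kept fixed
teacher_features = [0.5, -1.0, 2.0]  # hypothetical, different size

W = [[0.0, 0.0] for _ in teacher_features]
initial_loss = mse(apply_regressor(W, student_features), teacher_features)

lr = 0.1
n = len(teacher_features)
for _ in range(100):
    pred = apply_regressor(W, student_features)
    for i, row in enumerate(W):
        residual = pred[i] - teacher_features[i]
        for j in range(len(row)):
            # d(MSE)/dW[i][j] = 2 * residual_i * s_j / n
            row[j] -= lr * 2.0 * residual * student_features[j] / n

final_loss = mse(apply_regressor(W, student_features), teacher_features)
print(initial_loss, final_loss)
```

In the tutorial itself the regressor and the student are trained jointly, so both adapt; the sketch isolates the regressor only to show the extra degrees of freedom it adds.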


http://www.ppmy.cn/ops/156015.html
