知识蒸馏教程 Knowledge Distillation Tutorial

news/2025/2/6 19:05:50/

来自于:Knowledge Distillation Tutorial
将大模型蒸馏为小模型,可以节省计算资源,加快推理过程,更高效的运行。

使用CIFAR-10数据集

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasetsdevice = "cuda" #CPU也可
transforms_cifar = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])# Loading the CIFAR-10 dataset:
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transforms_cifar)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transforms_cifar)
#Dataloaders
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=2)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=2)

定义模型

定义两个结构相似,只是在宽度和深度不同的模型。
教师模型DeepNN

# Deeper neural network class to be used as teacher:
class DeepNN(nn.Module):def __init__(self, num_classes=10):super(DeepNN, self).__init__()self.features = nn.Sequential(nn.Conv2d(3, 128, kernel_size=3, padding=1),nn.ReLU(),nn.Conv2d(128, 64, kernel_size=3, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),nn.Conv2d(64, 64, kernel_size=3, padding=1),nn.ReLU(),nn.Conv2d(64, 32, kernel_size=3, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),)self.classifier = nn.Sequential(nn.Linear(2048, 512),nn.ReLU(),nn.Dropout(0.1),nn.Linear(512, num_classes))def forward(self, x):x = self.features(x)x = torch.flatten(x, 1)x = self.classifier(x)return x

学生模型LightNN

# Lightweight neural network class to be used as student:
class LightNN(nn.Module):def __init__(self, num_classes=10):super(LightNN, self).__init__()self.features = nn.Sequential(nn.Conv2d(3, 16, kernel_size=3, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),nn.Conv2d(16, 16, kernel_size=3, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),)self.classifier = nn.Sequential(nn.Linear(1024, 256),nn.ReLU(),nn.Dropout(0.1),nn.Linear(256, num_classes))def forward(self, x):x = self.features(x)x = torch.flatten(x, 1)x = self.classifier(x)return x

在这里插入图片描述

训练并测试模型

def train(model, train_loader, epochs, learning_rate, device):criterion = nn.CrossEntropyLoss()optimizer = optim.Adam(model.parameters(), lr=learning_rate)model.train()for epoch in range(epochs):running_loss = 0.0for inputs, labels in train_loader:# inputs: A collection of batch_size images# labels: A vector of dimensionality batch_size with integers denoting class of each imageinputs, labels = inputs.to(device), labels.to(device)optimizer.zero_grad()outputs = model(inputs)# outputs: Output of the network for the collection of images. A tensor of dimensionality batch_size x num_classes# labels: The actual labels of the images. Vector of dimensionality batch_sizeloss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}")def test(model, test_loader, device):model.to(device)model.eval()correct = 0total = 0with torch.no_grad():for inputs, labels in test_loader:inputs, labels = inputs.to(device), labels.to(device)outputs = model(inputs)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()accuracy = 100 * correct / totalprint(f"Test Accuracy: {accuracy:.2f}%")return accuracy
torch.manual_seed(42)
nn_deep = DeepNN(num_classes=10).to(device)
train(nn_deep, train_loader, epochs=10, learning_rate=0.001, device=device)
test_accuracy_deep = test(nn_deep, test_loader, device)# Instantiate the lightweight network:
torch.manual_seed(42)
nn_light = LightNN(num_classes=10).to(device)
train(nn_light, train_loader, epochs=10, learning_rate=0.001, device=device)
test_accuracy_light_ce = test(nn_light, test_loader, device)

DeepNN的参数量为1,186,986,准确率为75.98%。
LightNN的参数量为267,738,准确率为70.65%。

total_params_deep = "{:,}".format(sum(p.numel() for p in nn_deep.parameters()))
print(f"DeepNN parameters: {total_params_deep}")
total_params_light = "{:,}".format(sum(p.numel() for p in nn_light.parameters()))
print(f"LightNN parameters: {total_params_light}")
print(f"Teacher accuracy: {test_accuracy_deep:.2f}%")
print(f"Student accuracy: {test_accuracy_light_ce:.2f}%")

知识蒸馏

教师模型和学生模型都输出了关于类别的概率分布,假设认为,经过训练的教师模型输出的softmax结果携带了更多的信息,有助于提高学生模型的准确率。例如,在默认情况下,汽车、火车、摩托车的对应的label为 [1,0,0],经过训练的教师模型输出结果可能是 [0.6,0.2,0.2],而对于汽车、狗、猫,教师模型输出的结果可能是[0.8,0.1,0.1],汽车和火车、摩托车要比狗、猫更相似。让学生模型学习到教师模型的这部分知识,就称为知识蒸馏

学生模型与真实值的损失使用交叉熵损失。
学生模型与教师模型的损失使用KL散度损失。

蒸馏过程中,冻结教师模型,只训练学生模型。

增加参数:

  • T:温度,温度控制着输出分布的平滑度。较大的 T 会导致更平滑的分布,因此较小的概率会得到更大的提升。
  • soft_target_loss_weight:学生模型与教师模型的损失的权重。
  • ce_loss_weight:学生模型与真实值的损失的权重。
def train_knowledge_distillation(teacher, student, train_loader, epochs, learning_rate, T, soft_target_loss_weight, ce_loss_weight, device):ce_loss = nn.CrossEntropyLoss()optimizer = optim.Adam(student.parameters(), lr=learning_rate)teacher.eval()  # Teacher set to evaluation modestudent.train() # Student to train modefor epoch in range(epochs):running_loss = 0.0for inputs, labels in train_loader:inputs, labels = inputs.to(device), labels.to(device)optimizer.zero_grad()# Forward pass with the teacher model - do not save gradients here as we do not change the teacher's weightswith torch.no_grad():teacher_logits = teacher(inputs)# Forward pass with the student modelstudent_logits = student(inputs)#Soften the student logits by applying softmax first and log() secondsoft_targets = nn.functional.softmax(teacher_logits / T, dim=-1)soft_prob = nn.functional.log_softmax(student_logits / T, dim=-1)# Calculate the soft targets loss. Scaled by T**2 as suggested by the authors of the paper "Distilling the knowledge in a neural network"soft_targets_loss = torch.sum(soft_targets * (soft_targets.log() - soft_prob)) / soft_prob.size()[0] * (T**2)# Calculate the true label losslabel_loss = ce_loss(student_logits, labels)# Weighted sum of the two lossesloss = soft_target_loss_weight * soft_targets_loss + ce_loss_weight * label_lossloss.backward()optimizer.step()running_loss += loss.item()print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}")# Apply ``train_knowledge_distillation`` with a temperature of 2. Arbitrarily set the weights to 0.75 for CE and 0.25 for distillation loss.
train_knowledge_distillation(teacher=nn_deep, student=new_nn_light, train_loader=train_loader, epochs=10, learning_rate=0.001, T=2, soft_target_loss_weight=0.25, ce_loss_weight=0.75, device=device)
test_accuracy_light_ce_and_kd = test(new_nn_light, test_loader, device)# Compare the student test accuracy with and without the teacher, after distillation
print(f"Teacher accuracy: {test_accuracy_deep:.2f}%")
print(f"Student accuracy without teacher: {test_accuracy_light_ce:.2f}%")
print(f"Student accuracy with CE + KD: {test_accuracy_light_ce_and_kd:.2f}%")#Test Accuracy: 70.49%
#Teacher accuracy: 75.98%
#Student accuracy without teacher: 70.65%
#Student accuracy with CE + KD: 70.49%

CosineEmbeddingLoss

蒸馏的目标是让学生模型学习教师模型的知识,那么不只是学习最终的输出分布,也可以学习教师模型的内部表示hidden states。
可以比较两个模型的中间输出向量,使用CosineEmbeddingLoss。
在前面的模型中,教师模型flatten输出维度为2048,而学生模型为1024,因此在教师模型中加入额外池化层,让两个模型在同一个维度。

class ModifiedDeepNNCosine(nn.Module):def __init__(self, num_classes=10):super(ModifiedDeepNNCosine, self).__init__()self.features = nn.Sequential(nn.Conv2d(3, 128, kernel_size=3, padding=1),nn.ReLU(),nn.Conv2d(128, 64, kernel_size=3, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),nn.Conv2d(64, 64, kernel_size=3, padding=1),nn.ReLU(),nn.Conv2d(64, 32, kernel_size=3, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),)self.classifier = nn.Sequential(nn.Linear(2048, 512),nn.ReLU(),nn.Dropout(0.1),nn.Linear(512, num_classes))def forward(self, x):x = self.features(x)flattened_conv_output = torch.flatten(x, 1)x = self.classifier(flattened_conv_output)flattened_conv_output_after_pooling = torch.nn.functional.avg_pool1d(flattened_conv_output, 2)return x, flattened_conv_output_after_pooling# Create a similar student class where we return a tuple. We do not apply pooling after flattening.
class ModifiedLightNNCosine(nn.Module):def __init__(self, num_classes=10):super(ModifiedLightNNCosine, self).__init__()self.features = nn.Sequential(nn.Conv2d(3, 16, kernel_size=3, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),nn.Conv2d(16, 16, kernel_size=3, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),)self.classifier = nn.Sequential(nn.Linear(1024, 256),nn.ReLU(),nn.Dropout(0.1),nn.Linear(256, num_classes))def forward(self, x):x = self.features(x)flattened_conv_output = torch.flatten(x, 1)x = self.classifier(flattened_conv_output)return x, flattened_conv_output# We do not have to train the modified deep network from scratch of course, we just load its weights from the trained instance
modified_nn_deep = ModifiedDeepNNCosine(num_classes=10).to(device)
modified_nn_deep.load_state_dict(nn_deep.state_dict())# Once again ensure the norm of the first layer is the same for both networks
print("Norm of 1st layer for deep_nn:", torch.norm(nn_deep.features[0].weight).item())
print("Norm of 1st layer for modified_deep_nn:", torch.norm(modified_nn_deep.features[0].weight).item())# Initialize a modified lightweight network with the same seed as our other lightweight instances. This will be trained from scratch to examine the effectiveness of cosine loss minimization.
torch.manual_seed(42)
modified_nn_light = ModifiedLightNNCosine(num_classes=10).to(device)
print("Norm of 1st layer:", torch.norm(modified_nn_light.features[0].weight).item())

在这里插入图片描述
训练函数和测试函数也随之发生变化。

def train_cosine_loss(teacher, student, train_loader, epochs, learning_rate, hidden_rep_loss_weight, ce_loss_weight, device):ce_loss = nn.CrossEntropyLoss()cosine_loss = nn.CosineEmbeddingLoss()optimizer = optim.Adam(student.parameters(), lr=learning_rate)teacher.to(device)student.to(device)teacher.eval()  # Teacher set to evaluation modestudent.train() # Student to train modefor epoch in range(epochs):running_loss = 0.0for inputs, labels in train_loader:inputs, labels = inputs.to(device), labels.to(device)optimizer.zero_grad()# Forward pass with the teacher model and keep only the hidden representationwith torch.no_grad():_, teacher_hidden_representation = teacher(inputs)# Forward pass with the student modelstudent_logits, student_hidden_representation = student(inputs)# Calculate the cosine loss. Target is a vector of ones. From the loss formula above we can see that is the case where loss minimization leads to cosine similarity increase.hidden_rep_loss = cosine_loss(student_hidden_representation, teacher_hidden_representation, target=torch.ones(inputs.size(0)).to(device))# Calculate the true label losslabel_loss = ce_loss(student_logits, labels)# Weighted sum of the two lossesloss = hidden_rep_loss_weight * hidden_rep_loss + ce_loss_weight * label_lossloss.backward()optimizer.step()running_loss += loss.item()print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}")
def test_multiple_outputs(model, test_loader, device):model.to(device)model.eval()correct = 0total = 0with torch.no_grad():for inputs, labels in test_loader:inputs, labels = inputs.to(device), labels.to(device)outputs, _ = model(inputs) # Disregard the second tensor of the tuple_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()accuracy = 100 * correct / totalprint(f"Test Accuracy: {accuracy:.2f}%")return accuracy# Train and test the lightweight network with cross entropy loss
train_cosine_loss(teacher=modified_nn_deep, student=modified_nn_light, train_loader=train_loader, epochs=10, learning_rate=0.001, hidden_rep_loss_weight=0.25, ce_loss_weight=0.75, device=device)
test_accuracy_light_ce_and_cosine_loss = test_multiple_outputs(modified_nn_light, test_loader, device)
#Test Accuracy: 70.12%

Intermediate regressor run

对于高维度向量,余弦相似度通常比欧几里得距离效果更好,但我们处理的是每个具有 1024 个分量的向量,因此更难提取有意义的相似性。此外,正如我们所提到的,从理论上讲,推动教师和学生的隐藏表示相匹配是不被支持的。我们没有充分的理由应该追求这些向量的 1:1 匹配。
作者认为前面的蒸馏,学生模型和教师模型学习的是向量,即学习的是torch.flatten(x, 1),是一个向量,表达能力有限。因此选取 flatten 的前一层,学习卷积层的输出特征图。
教师模型的特征图shape为[128, 32, 8, 8],学生模型的特征图为[128, 16, 8, 8],需要添加一个卷积层,对齐维度。
在这里插入图片描述
在学生模型中加入了regressor层。

class ModifiedDeepNNRegressor(nn.Module):def __init__(self, num_classes=10):super(ModifiedDeepNNRegressor, self).__init__()self.features = nn.Sequential(nn.Conv2d(3, 128, kernel_size=3, padding=1),nn.ReLU(),nn.Conv2d(128, 64, kernel_size=3, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),nn.Conv2d(64, 64, kernel_size=3, padding=1),nn.ReLU(),nn.Conv2d(64, 32, kernel_size=3, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),)self.classifier = nn.Sequential(nn.Linear(2048, 512),nn.ReLU(),nn.Dropout(0.1),nn.Linear(512, num_classes))def forward(self, x):x = self.features(x)conv_feature_map = xx = torch.flatten(x, 1)x = self.classifier(x)return x, conv_feature_mapclass ModifiedLightNNRegressor(nn.Module):def __init__(self, num_classes=10):super(ModifiedLightNNRegressor, self).__init__()self.features = nn.Sequential(nn.Conv2d(3, 16, kernel_size=3, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),nn.Conv2d(16, 16, kernel_size=3, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),)# Include an extra regressor (in our case linear)self.regressor = nn.Sequential(nn.Conv2d(16, 32, kernel_size=3, padding=1))self.classifier = nn.Sequential(nn.Linear(1024, 256),nn.ReLU(),nn.Dropout(0.1),nn.Linear(256, num_classes))def forward(self, x):x = self.features(x)regressor_output = self.regressor(x)x = torch.flatten(x, 1)x = self.classifier(x)return x, regressor_output
def train_mse_loss(teacher, student, train_loader, epochs, learning_rate, feature_map_weight, ce_loss_weight, device):ce_loss = nn.CrossEntropyLoss()mse_loss = nn.MSELoss()optimizer = optim.Adam(student.parameters(), lr=learning_rate)teacher.to(device)student.to(device)teacher.eval()  # Teacher set to evaluation modestudent.train() # Student to train modefor epoch in range(epochs):running_loss = 0.0for inputs, labels in train_loader:inputs, labels = inputs.to(device), labels.to(device)optimizer.zero_grad()# Again ignore teacher logitswith torch.no_grad():_, teacher_feature_map = teacher(inputs)# Forward pass with the student modelstudent_logits, regressor_feature_map = student(inputs)# Calculate the losshidden_rep_loss = mse_loss(regressor_feature_map, teacher_feature_map)# Calculate the true label losslabel_loss = ce_loss(student_logits, labels)# Weighted sum of the two lossesloss = feature_map_weight * hidden_rep_loss + ce_loss_weight * label_lossloss.backward()optimizer.step()running_loss += loss.item()print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}")# Notice how our test function remains the same here with the one we used in our previous case. We only care about the actual outputs because we measure accuracy.# Initialize a ModifiedLightNNRegressor
torch.manual_seed(42)
modified_nn_light_reg = ModifiedLightNNRegressor(num_classes=10).to(device)# We do not have to train the modified deep network from scratch of course, we just load its weights from the trained instance
modified_nn_deep_reg = ModifiedDeepNNRegressor(num_classes=10).to(device)
modified_nn_deep_reg.load_state_dict(nn_deep.state_dict())# Train and test once again
train_mse_loss(teacher=modified_nn_deep_reg, student=modified_nn_light_reg, train_loader=train_loader, epochs=10, learning_rate=0.001, feature_map_weight=0.25, ce_loss_weight=0.75, device=device)
test_accuracy_light_ce_and_mse_loss = test_multiple_outputs(modified_nn_light_reg, test_loader, device)
print(f"Teacher accuracy: {test_accuracy_deep:.2f}%")
print(f"Student accuracy without teacher: {test_accuracy_light_ce:.2f}%")
print(f"Student accuracy with CE + KD: {test_accuracy_light_ce_and_kd:.2f}%")
print(f"Student accuracy with CE + CosineLoss: {test_accuracy_light_ce_and_cosine_loss:.2f}%")
print(f"Student accuracy with CE + RegressorMSE: {test_accuracy_light_ce_and_mse_loss:.2f}%")#Teacher accuracy: 75.98%
#Student accuracy without teacher: 70.65%
#Student accuracy with CE + KD: 70.49%
#Student accuracy with CE + CosineLoss: 70.12%
#Student accuracy with CE + RegressorMSE: 70.61%

RegressorMSE的方法会比 CosineLoss 效果更好,因为在教师和学生之间允许了一个可训练的层,这在学习方面给了学生模型一些回旋的余地,而不是迫使学生模型复制教师模型的表示。包括额外网络是基于提示蒸馏背后的理念。(Including the extra network is the idea behind hint-based distillation.)


http://www.ppmy.cn/news/1569874.html

相关文章

机器学习Machine Learning

机器学习(Machine Learning)是人工智能(AI)的一个分支,它使计算机系统能够利用数据和算法自动学习和改进其性能。 机器学习是让机器通过经验(数据)来做决策和预测。 机器学习已经广泛应用于许…

Servlet笔记(上)

简介 静态资源(客户端向服务端发出请求报文服务端直接生成响应报文发送给服务端)无需在程序运行时通过代码运行生成的资源,在程序运行之前就写好的资源 动态资源 (客户端向服务端发出请求报文服务端要先运行生成Java文件获得资源&#xff0…

Spring Boot 实例解析:从概念到代码

SpringBoot 简介: 简化 Spring 应用开发的一个框架整合 Spring 技术栈的一个大整合J2EE 开发的一站式解决方案优点:快速创建独立运行的 Spring 项目以及与主流框架集成使用嵌入式的 Servlet 容器,应用无需打成 war 包,内嵌 Tomcat…

应对现代电子商务的网络威胁—全面安全战略

在当今高度数字化的世界,电子商务企业面临的网络威胁比以往任何时候都更复杂和全球化。不再仅仅是简单的恶意软件或DDoS攻击,如今的威胁来源于复杂的黑客组织、精心设计的定向攻击,甚至是国家支持的网络犯罪活动。企业级电商平台,…

Excel 技巧22 - Ctrl+D 向下复制(★★),复制同间距图形

本文讲Excel中CtrlD 向下复制的用法。 这个是我特别喜欢和常用的功能,操作简单,功能强大。 1,CtrlD向下复制 1-1,单个单元格复制 最为常用的就是一个单元格的,就像下面这样的,也不用选中, 就…

MongoDB学习笔记-解析jsonCommand内容

如果需要屏蔽其他项目对MongoDB的直接访问操作&#xff0c;统一由一个入口访问操作MongoDB&#xff0c;可以考虑直接传入jsonCommand语句解析执行。 相关依赖包 <!-- SpringBootDataMongodb依赖包 --> <dependency><groupId>org.springframework.boot</…

【Golang学习之旅】Go 语言数据类型详解(string、slice、map等)

文章目录 前言1. Go语言数据类型概览2. Go语言基本数据类型2.1 整型&#xff08;int&#xff0c;uint&#xff0c;float&#xff09;2.2 布尔类型&#xff08;bool&#xff09;2.3 字符串&#xff08;string&#xff09; 3. Go 语言复合数据类型3.1 数组&#xff08;Array&…

CNN的各种知识点(一):卷积神经网络CNN通道数的理解!

卷积神经网络CNN通道数的理解&#xff01; 通道数的核心概念解析1. 通道数的本质 2. 单张灰度图的处理示例&#xff1a; 3. 批量输入的处理通道与批次的关系&#xff1a; 4. RGB三通道输入的处理计算过程&#xff1a;示例&#xff1a; 5. 通道数的实际意义6. 可视化理解(1) 单通…