老师学生蒸馏模型实战

news/2024/11/16 3:49:29/

 参考:知识蒸馏Pytorch代码实战_哔哩哔哩_bilibili

import torch
import torch.nn.functional as F
from torch import nn, optim
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchinfo import summary
import numpy as np
import matplotlib.pyplot as plt"""
不同的温度下,知识蒸馏的效果不一样,就是标签之间的数值差距不一样logits = np.array(torch.randn(1, 10))
print(logits)
# 普通softmax T = 1
softmax = np.exp(logits) / sum(np.exp(logits))# 蒸馏温度 softmax T = 3, T越大,差距就越小
T = 3
softmax_3 = np.exp(logits / T) / sum(np.exp(logits / T))
plt.plot(softmax, label='softmax')
plt.legend()
plt.show()"""# 设置随机种子,便于复现
torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 使用cuDNN加速卷积运算
torch.backends.cudnn.benchmark = True
# 加载数据集
train_dataset = torchvision.datasets.MNIST(root='dataset/', train=True, download=True, transform=transforms.ToTensor())
test_dataset = torchvision.datasets.MNIST(root='dataset/', train=False, download=True, transform=transforms.ToTensor())# 生成dataloader
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)# 教师模型
class TeacherModel(nn.Module):def __init__(self, num_classes=10):super(TeacherModel, self).__init__()self.relu = nn.ReLU()self.fc1 = nn.Linear(784, 1200)self.fc2 = nn.Linear(1200, 1200)self.fc3 = nn.Linear(1200, num_classes)self.dropout = nn.Dropout(p=0.5)def forward(self, x):x = x.view(-1, 784)x = self.relu(self.dropout(self.fc1(x)))x = self.relu(self.dropout(self.fc2(x)))x = self.fc3(x)return x# 学生模型
class StudentModel(nn.Module):def __init__(self, num_classes=10):super(StudentModel, self).__init__()self.relu = nn.ReLU()self.fc1 = nn.Linear(784, 5)self.fc2 = nn.Linear(5, 5)self.fc3 = nn.Linear(5, num_classes)self.dropout = nn.Dropout(p=0.5)def forward(self, x):x = x.view(-1, 784)# x = self.relu(self.dropout(self.fc1(x)))x = self.fc1(x)# x = self.relu(self.dropout(self.fc2(x)))x = self.fc3(x)return x"""
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:09<00:00, 192.13it/s]
Epoch:0, accuracy: 0.9405999779701233, loss: 0.2013
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:07<00:00, 237.35it/s]
Epoch:1, accuracy: 0.9599999785423279, loss: 0.0651
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:08<00:00, 233.87it/s]
Epoch:2, accuracy: 0.9693999886512756, loss: 0.2649
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:08<00:00, 227.13it/s]
Epoch:3, accuracy: 0.9746999740600586, loss: 0.1803
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:08<00:00, 219.15it/s]
Epoch:4, accuracy: 0.9768999814987183, loss: 0.0242"""
def train_teacher(epochs,  lr=1e-4):model = TeacherModel()model = model.to(device)# 查看模型参数量等信息# print(summary(model))criterion = nn.CrossEntropyLoss()optimizer = optim.Adam(model.parameters(), lr=lr)for epoch in range(epochs):model.train()# 在训练集上训练权重for data, target in (train_loader):data = data.to(device)target = target.to(device)# 前向预测preds = model(data)loss = criterion(preds, target)# 反向传播,优化权重optimizer.zero_grad()loss.backward()optimizer.step()# 在测试集上评估model.eval()correct = 0samples = 0with torch.no_grad():for data, target in test_loader:data = data.to(device)target = target.to(device)preds = model(data)predictions = preds.argmax(dim=1)correct += (predictions == target).sum()samples += predictions.size(0)accuracy = (correct / samples).item()model.train()print(f'Epoch:{epoch}, accuracy: {accuracy}, loss: {loss:.4f}')# 保存模型的权重和结构torch.save(model, '/home/wangyp/Big_Model/Knowledge_Distillation/teacher_model.ckpt')"""
Epoch:0, accuracy: 0.7197999954223633, loss: 1.3583"""
def train_student(epochs,  lr=1e-4):model = StudentModel()model = model.to(device)# 查看模型参数量等信息# print(summary(model))criterion = nn.CrossEntropyLoss()optimizer = optim.Adam(model.parameters(), lr=lr)for epoch in range(epochs):model.train()# 在训练集上训练权重for data, target in (train_loader):data = data.to(device)target = target.to(device)# 前向预测preds = model(data)loss = criterion(preds, target)# 反向传播,优化权重optimizer.zero_grad()loss.backward()optimizer.step()# 在测试集上评估model.eval()correct = 0samples = 0with torch.no_grad():for x, y in test_loader:x = x.to(device)y = y.to(device)preds = model(x)predictions = preds.argmax(dim=1)correct += (predictions == y).sum()samples += predictions.size(0)accuracy = (correct / samples).item()model.train()print(f'Epoch:{epoch}, accuracy: {accuracy}, loss: {loss:.4f}')"""
开始蒸馏
"""
def teacher_student(epochs=3, teacher_model=None, temp=7,  lr=1e-4):# 准备好已经训练好的教师模型teacher_model.eval()# 准备没有训练过得学生模型student_model = StudentModel()student_model = student_model.to(device)student_model.train()# 蒸馏温度 temphard_loss = nn.CrossEntropyLoss()hard_loss_alpha = 0.3# 计算两个分布的相似度soft_loss = nn.KLDivLoss(reduction='batchmean')optimizer = optim.Adam(student_model.parameters(), lr=lr)for epoch in range(epochs):for data, target in (train_loader):data = data.to(device)target = target.to(device)# 教师模型预测with torch.no_grad():teacher_preds = teacher_model(data)# 学生模型预测student_preds = student_model(data)student_loss = hard_loss(student_preds, target)# 计算蒸馏后的预测结果以及soft_loss, 两个softmax的分布差异大小distillation_loss = soft_loss(F.softmax(student_preds / temp, dim=1),F.softmax(teacher_preds / temp, dim=1))loss = hard_loss_alpha * student_loss + (1-hard_loss_alpha) * distillation_loss# print(f'Epoch:{epoch}, distillation_loss: {distillation_loss}, student_loss: {student_loss}')# 反向传播,优化权重optimizer.zero_grad()loss.backward()optimizer.step()# 测试集student_model.eval()correct=0samples=0with torch.no_grad():for x, y in test_loader:x = x.to(device)y = y.to(device)preds = student_model(x)predictions = preds.argmax(dim=1)correct += (predictions == y).sum()samples += predictions.size(0)acc = (correct / samples).item()student_model.train()print(f'Epoch:{epoch}, accuracy: {acc}, loss: {loss:.4f}')passif __name__ == '__main__':# 训练教师模型# train_teacher(3)# train_student(1)# teacher_studentprint("开始加载模型。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。")# 加载模型的权重和结构teacher_model = torch.load('/home/wangyp/Big_Model/Knowledge_Distillation/teacher_model.ckpt')teacher_student(epochs=3, teacher_model=teacher_model, temp=7, lr=1e-4)


http://www.ppmy.cn/news/1423128.html

相关文章

ChatGPT辅助下的论文写作之道

ChatGPT无限次数:点击直达 ChatGPT辅助下的论文写作之道 在当今信息爆炸的时代&#xff0c;学术论文写作是每个研究者和学生不可或缺的技能。然而&#xff0c;对于许多人来说&#xff0c;写作是一个具有挑战性和耗时的过程。幸运的是&#xff0c;随着人工智能技术的不断进步&a…

Linux 1.文件编程(dup、dup2)

重定向 重定向是什么&#xff1f;dupdup2 重定向是什么&#xff1f; 进程在最开始运行的时候&#xff0c;首先打开了三个文件&#xff0c;分别是标准输入流、标准输出流、标准错误输出流。证明的时候我是把标准输出留给关闭了&#xff0c;然后紧接着创建的文件就会占用已关闭的…

上班记事备忘的软件有什么 工作记事本软件

在繁忙的工作节奏中&#xff0c;我常常感到自己的记忆力似乎不够用。开会时&#xff0c;脑海中灵光一闪的想法&#xff0c;转眼就忘得一干二净&#xff1b;与客户的约定&#xff0c;总是在忙碌中错过。记性不好&#xff0c;不仅影响了工作效率&#xff0c;更让我在同事和客户面…

c++ 拷贝构造函数 简单实验

1.概要 什么时候调用拷贝构造&#xff0c;就是有对象创建的时候&#xff0c;就会调用拷贝构造&#xff0c;无论对象是构造&#xff08;类&#xff08;来源&#xff09;&#xff09;还是赋值&#xff08;对象&#xff09;都会调用拷贝构造。 赋值函数调用的时机是两个对象都已经…

想要私域流量翻倍?这四个关键要素绝对不能错过!

在当今“得流量者得天下”的时代&#xff0c;拥有稳定且高质量的私域流量对于企业或是个人来说至关重要。然而&#xff0c;如何才能实现私域流量翻倍呢&#xff1f; 今天就给大家分享私域流量的四个关键要素&#xff0c;让大家都能实现私域流量的快速增长。 第一个&#xff1…

ChatGPT之道:巧用写作技巧

ChatGPT无限次数:点击直达 ChatGPT之道&#xff1a;巧用写作技巧 在当今快节奏的社会中&#xff0c;写作是一项非常重要的技能&#xff0c;不仅可以帮助我们有效表达思想&#xff0c;还可以提升个人能力和吸引力。而借助人工智能技术&#xff0c;如OpenAI推出的ChatGPT&#x…

遇事不决 量子力学?

文章目录 引入量子力学产生的必然性量子力学名称的由来粒子&#xff1f;波&#xff1f;波粒二象性测不准原理 &#xff08;不确定原理&#xff09;叠加态原理 量子纠缠态叠加量子纠缠量子纠缠实验 逻辑判断&#xff0c;量子力学到底完善吗观测量子纠缠&#xff1f;那我们宏观世…

Java面试:MySQL面试题汇总

1.说一下 MySQL 执行一条查询语句的内部执行过程&#xff1f; 答&#xff1a;MySQL 执行一条查询的流程如下&#xff1a; 客户端先通过连接器连接到 MySQL 服务器&#xff1b;连接器权限验证通过之后&#xff0c;先查询是否有查询缓存&#xff0c;如果有缓存&#xff08;之前…