# Reference: 知识蒸馏Pytorch代码实战_哔哩哔哩_bilibili (Knowledge Distillation in PyTorch, bilibili tutorial)
import torch
import torch.nn.functional as F
from torch import nn, optim
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchinfo import summary
import numpy as np
import matplotlib.pyplot as plt

# Dead demo kept as a comment-string: shows how the softmax distribution
# flattens as the distillation temperature T grows.
"""
At different temperatures the effect of knowledge distillation differs:
the numeric gaps between the class probabilities change.

logits = np.array(torch.randn(1, 10))
print(logits)
# plain softmax, T = 1
softmax = np.exp(logits) / sum(np.exp(logits))
# distillation softmax, T = 3; the larger T is, the smaller the gaps
T = 3
softmax_3 = np.exp(logits / T) / sum(np.exp(logits / T))
plt.plot(softmax, label='softmax')
plt.legend()
plt.show()
"""

# Fix the random seed for reproducibility.
torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Let cuDNN benchmark and pick the fastest convolution algorithms.
torch.backends.cudnn.benchmark = True

# Load the MNIST dataset (downloaded to dataset/ on first run).
train_dataset = torchvision.datasets.MNIST(root='dataset/', train=True, download=True, transform=transforms.ToTensor())
test_dataset = torchvision.datasets.MNIST(root='dataset/', train=False, download=True, transform=transforms.ToTensor())

# Wrap the datasets in dataloaders.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
class TeacherModel(nn.Module):
    """Teacher network: a large 3-layer MLP (784 -> 1200 -> 1200 -> num_classes).

    Expects flattenable 28x28 MNIST inputs; returns raw logits (no softmax).
    """

    def __init__(self, num_classes=10):
        super(TeacherModel, self).__init__()
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(784, 1200)
        self.fc2 = nn.Linear(1200, 1200)
        self.fc3 = nn.Linear(1200, num_classes)
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, x):
        # Flatten (N, 1, 28, 28) images into (N, 784) vectors.
        x = x.view(-1, 784)
        x = self.relu(self.dropout(self.fc1(x)))
        x = self.relu(self.dropout(self.fc2(x)))
        # Final layer outputs logits; CrossEntropyLoss applies softmax itself.
        x = self.fc3(x)
        return x
class StudentModel(nn.Module):
    """Student network: a tiny MLP (784 -> 5 -> num_classes).

    Returns raw logits. Deliberately under-parameterized so the effect of
    distillation is visible.
    """

    def __init__(self, num_classes=10):
        super(StudentModel, self).__init__()
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(784, 5)
        self.fc2 = nn.Linear(5, 5)
        self.fc3 = nn.Linear(5, num_classes)
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, x):
        # NOTE(review): fc2, relu and dropout are defined but never used in
        # this forward pass -- the commented-out lines suggest they were
        # disabled on purpose during experimentation. Confirm before removing.
        x = x.view(-1, 784)
        # x = self.relu(self.dropout(self.fc1(x)))
        x = self.fc1(x)
        # x = self.relu(self.dropout(self.fc2(x)))
        x = self.fc3(x)
        return x

# Teacher training log (kept for reference):
"""
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:09<00:00, 192.13it/s]
Epoch:0, accuracy: 0.9405999779701233, loss: 0.2013
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:07<00:00, 237.35it/s]
Epoch:1, accuracy: 0.9599999785423279, loss: 0.0651
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:08<00:00, 233.87it/s]
Epoch:2, accuracy: 0.9693999886512756, loss: 0.2649
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:08<00:00, 227.13it/s]
Epoch:3, accuracy: 0.9746999740600586, loss: 0.1803
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:08<00:00, 219.15it/s]
Epoch:4, accuracy: 0.9768999814987183, loss: 0.0242"""
def train_teacher(epochs, lr=1e-4,
                  save_path='/home/wangyp/Big_Model/Knowledge_Distillation/teacher_model.ckpt'):
    """Train the teacher model on MNIST, evaluate each epoch, then save it.

    Args:
        epochs: number of training epochs.
        lr: Adam learning rate.
        save_path: where to save the trained model (added as a parameter;
            defaults to the original hard-coded path for compatibility).
    """
    model = TeacherModel()
    model = model.to(device)
    # Inspect parameter counts etc.:
    # print(summary(model))
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    for epoch in range(epochs):
        model.train()
        # Train on the training set.
        for data, target in train_loader:
            data = data.to(device)
            target = target.to(device)
            # Forward pass.
            preds = model(data)
            loss = criterion(preds, target)
            # Backward pass and weight update.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        # Evaluate on the test set.
        model.eval()
        correct = 0
        samples = 0
        with torch.no_grad():
            for data, target in test_loader:
                data = data.to(device)
                target = target.to(device)
                preds = model(data)
                predictions = preds.argmax(dim=1)
                correct += (predictions == target).sum()
                samples += predictions.size(0)
        accuracy = (correct / samples).item()
        model.train()
        # NOTE: `loss` here is only the last training batch's loss.
        print(f'Epoch:{epoch}, accuracy: {accuracy}, loss: {loss:.4f}')
    # Save the whole model (weights and structure).
    torch.save(model, save_path)

# Student baseline log (kept for reference):
"""
Epoch:0, accuracy: 0.7197999954223633, loss: 1.3583"""
def train_student(epochs, lr=1e-4):
    """Train the small student model directly on hard labels (no distillation).

    Serves as the baseline to compare against the distilled student.

    Args:
        epochs: number of training epochs.
        lr: Adam learning rate.
    """
    model = StudentModel()
    model = model.to(device)
    # Inspect parameter counts etc.:
    # print(summary(model))
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    for epoch in range(epochs):
        model.train()
        # Train on the training set.
        for data, target in train_loader:
            data = data.to(device)
            target = target.to(device)
            # Forward pass.
            preds = model(data)
            loss = criterion(preds, target)
            # Backward pass and weight update.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        # Evaluate on the test set.
        model.eval()
        correct = 0
        samples = 0
        with torch.no_grad():
            for x, y in test_loader:
                x = x.to(device)
                y = y.to(device)
                preds = model(x)
                predictions = preds.argmax(dim=1)
                correct += (predictions == y).sum()
                samples += predictions.size(0)
        accuracy = (correct / samples).item()
        model.train()
        # NOTE: `loss` here is only the last training batch's loss.
        print(f'Epoch:{epoch}, accuracy: {accuracy}, loss: {loss:.4f}')

"""
开始蒸馏
"""
def teacher_student(epochs=3, teacher_model=None, temp=7, lr=1e-4):
    """Distill a trained teacher into a fresh student model.

    The total loss is a weighted sum of the hard-label cross-entropy and the
    KL divergence between the temperature-softened teacher and student
    distributions.

    Args:
        epochs: number of distillation epochs.
        teacher_model: trained teacher network (frozen during distillation).
        temp: distillation temperature; higher values soften the distributions.
        lr: Adam learning rate for the student.
    """
    # Teacher is already trained; keep it in eval mode.
    teacher_model.eval()
    # Fresh, untrained student.
    student_model = StudentModel()
    student_model = student_model.to(device)
    student_model.train()
    hard_loss = nn.CrossEntropyLoss()
    hard_loss_alpha = 0.3
    # KL divergence measures how far the student's softened distribution is
    # from the teacher's.
    soft_loss = nn.KLDivLoss(reduction='batchmean')
    optimizer = optim.Adam(student_model.parameters(), lr=lr)
    for epoch in range(epochs):
        for data, target in train_loader:
            data = data.to(device)
            target = target.to(device)
            # Teacher prediction (no gradients needed).
            with torch.no_grad():
                teacher_preds = teacher_model(data)
            # Student prediction.
            student_preds = student_model(data)
            student_loss = hard_loss(student_preds, target)
            # BUG FIX: nn.KLDivLoss expects *log*-probabilities as its first
            # argument; the original passed plain softmax, producing wrong
            # gradients. (Hinton et al. also scale this term by temp**2;
            # left unscaled here to stay close to the original recipe.)
            distillation_loss = soft_loss(
                F.log_softmax(student_preds / temp, dim=1),
                F.softmax(teacher_preds / temp, dim=1))
            loss = hard_loss_alpha * student_loss + (1 - hard_loss_alpha) * distillation_loss
            # print(f'Epoch:{epoch}, distillation_loss: {distillation_loss}, student_loss: {student_loss}')
            # Backward pass and weight update.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        # Evaluate the student on the test set.
        student_model.eval()
        correct = 0
        samples = 0
        with torch.no_grad():
            for x, y in test_loader:
                x = x.to(device)
                y = y.to(device)
                preds = student_model(x)
                predictions = preds.argmax(dim=1)
                correct += (predictions == y).sum()
                samples += predictions.size(0)
        acc = (correct / samples).item()
        student_model.train()
        print(f'Epoch:{epoch}, accuracy: {acc}, loss: {loss:.4f}')


if __name__ == '__main__':
    # Train the teacher / baseline student first if needed:
    # train_teacher(3)
    # train_student(1)
    print("开始加载模型。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。")
    # SECURITY NOTE: torch.load unpickles arbitrary objects -- only load
    # checkpoints from trusted sources.
    teacher_model = torch.load('/home/wangyp/Big_Model/Knowledge_Distillation/teacher_model.ckpt')
    teacher_student(epochs=3, teacher_model=teacher_model, temp=7, lr=1e-4)