【深度学习】(11)--迁移学习

news/2024/9/29 14:40:47/

文章目录

  • 迁移学习
    • 一、迁移学习步骤
    • 二、以残差网络为例
      • 1. 导入模型
      • 2. 冻结参数
      • 3. 修改全连接层
      • 4. 创建数据集的类
      • 5. 处理数据
      • 6. 装配设备
      • 7. 建立模型
      • 8. 训练模型
    • 三、完整代码展示
  • 总结

迁移学习

迁移学习是指利用已经训练好的模型,在新的任务上进行微调。迁移学习可以加快模型训练速度,提高模型性能,并且在数据稀缺的情况下也能很好地工作。

一、迁移学习步骤

  1. 选择预训练的模型和适当的层:通常,我们会选择在大规模图像数据集(如ImageNet)上预训练的模型,如VGG、ResNet等。然后,根据新数据集的特点,选择需要微调的模型层。对于低级特征的任务(如边缘检测),最好使用浅层模型的层,而对于高级特征的任务(如分类),则应选择更深层次的模型。
  2. 冻结预训练模型的参数:保持预训练模型的权重不变,只训练新增加的层或者微调一些层,避免因为在数据集中过拟合导致预训练模型过度拟合。
  3. 在新数据集上训练新增加的层:在冻结预训练模型的参数情况下,训练新增加的层。这样,可以使新模型适应新的任务,从而获得更高的性能。
  4. 微调预训练模型的层:在新层上进行训练后,可以解冻一些已经训练过的层,并且将它们作为微调的目标。这样做可以提高模型在新数据集上的性能。
  5. 评估和测试:在训练完成之后,使用测试集对模型进行评估。如果模型的性能仍然不够好,可以尝试调整超参数或者更改微调层。

二、以残差网络为例

1. 导入模型

torchvision中导入模型,库中已经存放好了大量模型框架。

import torchvision.models as models
resnet_model = models.resnet18(weights = models.ResNet18_Weights.DEFAULT)
# weights = models.ResNet18_Weights.DEFAULT表示在使用ImageNet数据集上预先训练好的权重来初始化模型参数

2. 冻结参数

冻结参数,使得在反向传播过程中,不要在计算他们的梯度,减少计算量。

for param in resnet_model.parameters():print(param)# 模型所有的参数(权重和偏置项)的requires_grad属性设置为False,冻结所有模型参数# 使得在反向传播过程中,不要在计算他们的梯度,减少计算量param.requires_grad = False

3. 修改全连接层

因为原本模型中的输出有1000种特征,而我们现在训练的数据仅有20种特征,需要需改输出:

# 获取模型原输入的特征个数
in_features = resnet_model.fc.in_features
# 创建一个全连接层(将原全连接层覆盖),输入特征为in_features,输出为20
resnet_model.fc = nn.Linear(in_features,20)params_to_update = [] # 保存需要训练的参数,仅训练修改的全连接层参数
for param in resnet_model.parameters():if param.requires_grad == True:params_to_update.append(param)

4. 创建数据集的类

残差模型的传入数据大小为(224),所以要对数据进行裁剪
data_transforms = {'train':transforms.Compose([transforms.Resize([300,300]),transforms.RandomRotation(45), # 随机旋转,-45到45度之间随便选transforms.CenterCrop(224), # 从中心开始剪裁transforms.RandomHorizontalFlip(p=0.5),# 随机水平反转,设定一个概率transforms.RandomVerticalFlip(p=0.5),# 随机垂直反转transforms.ColorJitter(brightness=0.2,contrast=0.1,saturation=0.1,hue=0.1),# 参数1亮度,参数2对比度,参数3饱和度,参数4色相transforms.RandomGrayscale(p=0.1),# 转化为灰度图transforms.ToTensor(),transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225]) # 标准化:均值,标准差(统一的)]),'valid':transforms.Compose([transforms.Resize([224,224]),transforms.ToTensor(),transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])]),
}

5. 处理数据

划分数据中的特征与标签:

"""-----处理数据-----"""
class food_dataset(Dataset):def __init__(self,file_path,transform = None):self.file_path = file_pathself.imgs = []self.labels = []self.transform = transformwith open(self.file_path) as f:samples = [x.strip().split(' ') for x in f.readlines()]for img_path,label in samples:self.imgs.append(img_path) # 特征self.labels.append(label)# 标签def __len__(self):return len(self.imgs)def __getitem__(self, idx):image = Image.open(self.imgs[idx])if self.transform:image = self.transform(image)label = self.labels[idx]label = torch.from_numpy(np.array(label,dtype=np.int64))return image,label
training_data = food_dataset(file_path='trainda.txt',transform=data_transforms['train'])
test_data = food_dataset(file_path='testda.txt',transform=data_transforms['valid'])train_dataloader = DataLoader(training_data,batch_size=64,shuffle=True)
test_dataloader = DataLoader(test_data,batch_size=64,shuffle=True)

6. 装配设备

"""---判断当前设备是否支持GPU,其中mps是苹果m系列芯片的GPU"""
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

7. 建立模型

"""-----建立模型-----"""
model = resnet_model.to(device)

8. 训练模型

"""-----训练集-----"""
def train(dataloader,model,loss_fn,optimizer):model.train()batch_size_num =1for x,y in dataloader:x,y = x.to(device),y.to(device)pred = model.forward(x)loss = loss_fn(pred,y)optimizer.zero_grad()loss.backward()optimizer.step()loss_value = loss.item()  # 获取损失值if batch_size_num %20 == 0:  # 每200次迭代打印一次损失print(f"loss:{loss_value:>7f} [number:{batch_size_num}]")batch_size_num += 1best_acc = 0
"""-----测试集-----"""
def test(dataloader,model,loss_fn):global best_accsize = len(dataloader.dataset)num_batches = len(dataloader)model.eval()test_loss,correct = 0,0with torch.no_grad():for x,y in dataloader:x,y = x.to(device),y.to(device)pred = model.forward(x)test_loss += loss_fn(pred,y).item()correct += (pred.argmax(1) == y).type(torch.float).sum().item()test_loss /= num_batchescorrect /= sizecorrect = round(correct, 4)print(f"Test result: \n Accuracy:{(100*correct)}%,Avg loss:{test_loss}")acc_s.append(correct)loss_s.append(test_loss)if correct > best_acc:best_acc = correct
"""-----损失函数-----"""
loss_fn = nn.CrossEntropyLoss()"""-----优化器-----"""
optimizer = torch.optim.Adam(params_to_update,lr=0.001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer,step_size=5,gamma=0.5)epochs = 100
acc_s = []
loss_s = []
for t in range(epochs):print(f"Epoch {t+1} \n-------------------------")train(train_dataloader,model,loss_fn,optimizer)scheduler.step()test(test_dataloader,model,loss_fn)
print('最优训练结果:',best_acc)

结果:

在这里插入图片描述

三、完整代码展示

import torch
import torchvision.models as models
from torch import nn
from torch.utils.data import Dataset,DataLoader
import numpy as np
from PIL import Image
from torchvision import transforms"""将resnet18模型迁移到食物分类项目中"""#残差网络是固定的网络结构,不需要自己来类定义
resnet_model = models.resnet18(weights = models.ResNet18_Weights.DEFAULT)
# weights = models.ResNet18_Weights.DEFAULT表示在使用ImageNet数据集上预先训练好的权重来初始化模型参数
for param in resnet_model.parameters():print(param)# 模型所有的参数(权重和偏置项)的requires_grad属性设置为False,冻结所有模型参数# 使得在反向传播过程中,不要在计算他们的梯度,减少计算量param.requires_grad = False"""-----修改残差模型中的全连接层-----"""# 因为原本模型中的输出有1000种特征,而我们现在训练的数据仅有20种特征,需要需改输出
# 获取模型原输入的特征个数
in_features = resnet_model.fc.in_features
# 创建一个全连接层(将原全连接层覆盖),输入特征为in_features,输出为20
resnet_model.fc = nn.Linear(in_features,20)params_to_update = [] # 保存需要训练的参数,进训练修改的全连接层
for param in resnet_model.parameters():if param.requires_grad == True:params_to_update.append(param)"""-----创建数据集的类-----"""# 残差模型的传入数据大小为(224),所以要对数据进行裁剪
data_transforms = {'train':transforms.Compose([transforms.Resize([300,300]),transforms.RandomRotation(45), # 随机旋转,-45到45度之间随便选transforms.CenterCrop(224), # 从中心开始剪裁transforms.RandomHorizontalFlip(p=0.5),# 随机水平反转,设定一个概率transforms.RandomVerticalFlip(p=0.5),# 随机垂直反转transforms.ColorJitter(brightness=0.2,contrast=0.1,saturation=0.1,hue=0.1),# 参数1亮度,参数2对比度,参数3饱和度,参数4色相transforms.RandomGrayscale(p=0.1),# 转化为灰度图transforms.ToTensor(),transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225]) # 标准化:均值,标准差(统一的)]),'valid':transforms.Compose([transforms.Resize([224,224]),transforms.ToTensor(),transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])]),
}"""-----处理数据-----"""
class food_dataset(Dataset):def __init__(self,file_path,transform = None):self.file_path = file_pathself.imgs = []self.labels = []self.transform = transformwith open(self.file_path) as f:samples = [x.strip().split(' ') for x in f.readlines()]for img_path,label in samples:self.imgs.append(img_path)self.labels.append(label)def __len__(self):return len(self.imgs)def __getitem__(self, idx):image = Image.open(self.imgs[idx])if self.transform:image = self.transform(image)label = self.labels[idx]label = torch.from_numpy(np.array(label,dtype=np.int64))return image,label"""---判断当前设备是否支持GPU,其中mps是苹果m系列芯片的GPU"""
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using {device} device")"""-----数据处理-----"""
training_data = food_dataset(file_path='trainda.txt',transform=data_transforms['train'])
test_data = food_dataset(file_path='testda.txt',transform=data_transforms['valid'])train_dataloader = DataLoader(training_data,batch_size=64,shuffle=True)
test_dataloader = DataLoader(test_data,batch_size=64,shuffle=True)"""-----建立模型-----"""
model = resnet_model.to(device)"""-----损失函数-----"""
loss_fn = nn.CrossEntropyLoss()"""-----优化器-----"""
optimizer = torch.optim.Adam(params_to_update,lr=0.001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer,step_size=5,gamma=0.5)"""-----训练集-----"""
def train(dataloader,model,loss_fn,optimizer):model.train()batch_size_num =1for x,y in dataloader:x,y = x.to(device),y.to(device)pred = model.forward(x)loss = loss_fn(pred,y)optimizer.zero_grad()loss.backward()optimizer.step()loss_value = loss.item()  # 获取损失值if batch_size_num %20 == 0:  # 每200次迭代打印一次损失print(f"loss:{loss_value:>7f} [number:{batch_size_num}]")batch_size_num += 1best_acc = 0
"""-----测试集-----"""
def test(dataloader,model,loss_fn):global best_accsize = len(dataloader.dataset)num_batches = len(dataloader)model.eval()test_loss,correct = 0,0with torch.no_grad():for x,y in dataloader:x,y = x.to(device),y.to(device)pred = model.forward(x)test_loss += loss_fn(pred,y).item()correct += (pred.argmax(1) == y).type(torch.float).sum().item()test_loss /= num_batchescorrect /= sizecorrect = round(correct, 4)print(f"Test result: \n Accuracy:{(100*correct)}%,Avg loss:{test_loss}")acc_s.append(correct)loss_s.append(test_loss)if correct > best_acc:best_acc = correct"""-----训练模型-----"""
epochs = 100
acc_s = []
loss_s = []
for t in range(epochs):print(f"Epoch {t+1} \n-------------------------")train(train_dataloader,model,loss_fn,optimizer)scheduler.step()test(test_dataloader,model,loss_fn)
print('最优训练结果:',best_acc)

总结

本篇介绍了:

  1. 如何进行迁移学习
  2. 对迁移模型进行微调:
    1. 微调全连接层
    2. 微调卷积层(本篇未写),原理相同,可自行尝试
  3. 注意:原本的模型参数务必要冻结住,那是已经调好的,可以节省计算时间。仅需要调整修改部分的参数。

http://www.ppmy.cn/news/1531953.html

相关文章

大厂面试真题-说一下Mybatis的缓存

首先看一下原理图 Mybatis提供了两种缓存机制:一级缓存(L1 Cache)和二级缓存(L2 Cache),旨在提高数据库查询的性能,减少数据库的访问次数。注意查询的顺序是先二级缓存,再一级缓存。…

软件架构思考

title: 软件架构思考 date: 2019-03-01 14:07:48 tags: [tips] categories: tips 架构是对工程整体结构与组件的抽象描述,是软件工程的基础骨架。架构在工程层面不分领域,且思想是通用的。引用维基百科对于软件架构的定义: 软件体系结构是构…

PHP中如何使用三元条件运算符

在PHP中,三元条件运算符(也称为三元运算符或条件运算符)是一种非常紧凑的写法,用于根据条件表达式的真假值来返回两个值中的一个。尽管你的请求要求5000字的内容,但实际上这个主题相当直接且简短,因为它基于…

一带一路区块链赛项样题解析(中)

一带一路区块链赛项样题解析 (模块二) 标题任务一 按要求完成智能合约开发 1、学籍信息合约(Roll)接口编码(6分) (1)编写学籍信息合约中的RollInfo 实体接口,完成RollInfo实体通用数据的初始化,实现可追溯的学籍信息上链功能;(2分) // SPDX-License-Identifie…

电脑自带dll修复在哪里,dll丢失的6种解决方法总结

在现代科技日新月异的时代,电脑已经成为我们生活中不可或缺的一部分。然而,在使用电脑的过程中,我们常常会遇到一些常见的问题,其中之一就是dll文件丢失或损坏。当这些dll文件丢失或损坏时,可能会导致某些应用程序无法…

Promise从入门到提高实战教程

一、Promiss 介绍 Promise是一门新的技术(ES6规范),是JS中进行异步编程的新解决方案。 从语法上说,他是一个构造函数,Promise对象用来封装一个异步操作并可以获取成功/失败的结果值。(也就是包裹一个异步操作)创造Promise实例时,必须传入一个函数作为参数,一般传递一…

【FastAPI】使用 SQLAlchemy 和 FastAPI 实现 PostgreSQL 中的 JSON 数据 CRUD 操作

在现代 web 开发中,处理 JSON 数据变得越来越普遍。本文将指导你如何使用 FastAPI 和 SQLAlchemy 实现对 PostgreSQL 数据库中 JSON 数据的增删改查(CRUD)操作。 环境准备 首先,确保你已经安装了所需的库。在终端中运行以下命令…

RPA助力企业办公流程自动化:真实应用案例展示

在当今快速变化的商业环境中,企业面临着前所未有的挑战和机遇。数字化转型已成为企业提升竞争力、优化运营效率和增强客户体验的关键策略。RPA数字员工作为这一转型过程中的重要工具,正在帮助企业实现办公流程的自动化,从而加速数字化转型的步…