Pytorch学习---基于经典网络架构ResNet训练花卉图像分类模型

news/2024/9/21 17:50:40/

基于经典网络架构训练图像分类模型

导包


import copy
import json
import time
import torch
from torch import nn
import torch.optim as optim
import torchvision
import os
from torchvision import transforms, models, datasets
import numpy as np
import matplotlib.pyplot as plt
import ssl

冻结中间层的所有参数,只训练最后输出全连接层

def set_parameter_requeires_grad(model,feature_extracting):"""set_parameter_requires_grad 函数的作用是根据 feature_extracting 参数的值来决定是否冻结模型的参数。当用于特征提取时,它会阻止预训练模型的参数在训练过程中被更新,从而保留预训练模型的特征提取能力。当用于微调时,它不会修改参数的 requires_grad 属性,从而允许所有参数被更新。"""# 该函数会遍历模型的所有参数,并将它们的 requires_grad 属性设置为 Falseif feature_extracting:for param in model.parameters():param.requires_grad = False

是否用gpu进行训练

def get_device() -> torch.device:"""确定并返回用于训练的设备(CPU 或 GPU)"""train_on_gpu = torch.cuda.is_available()if not train_on_gpu:print('你的gpu不可用,尝试在cpu训练')else:print('gpu可用,训练中在gpu中进行')#torch.device() 是一个用于创建设备对象的构造函数,它可以指定张量和模型运行在 CPU 还是 GPU 上。device = torch.device('cuda:0'if torch.cuda.is_available() else "cpu")return device

选择迁移模型,这里选择resnet残差神经网络,不同模型的初始化方法稍微有点不同

def initialize_model(model_name, num_classes,feature_extract,use_pretrained = True):"""用于初始化一个特定的深度学习模型( ResNet-152),并将它用于图像分类任务。:param model_name:要学习的模型名称:param num_classes:指定分类任务的目标类别数量:param feature_extrace:是否进行特征提取(冻结训练层):param use_pretrained:是否使用预训练的模型,如果为TRUE,则使用ImageNet预训练的权重初始化模型,false则随机初始化权重:return:"""if model_name == 'resnet':model_ft = models.resnet152(pretrained=use_pretrained)set_parameter_requeires_grad(model_ft,feature_extract)# model_ft.fc 是指模型的最后一个全连接层(分类层)。.in_features 是一个属性,它表示这个全连接层的输入特征数量。num_ftrs = model_ft.fc.in_features# 将原模型的最后一层替换为一个适合当前任务的分类层,输出节点数量为 num_classes。新分类层由一个线性层和一个 LogSoftmax 层组成,用于输出分类概率。model_ft.fc = nn.Sequential(nn.Linear(num_ftrs,num_classes),nn.LogSoftmax(dim=1))input_size = 224else:print('无效模型,不存在!')exit()return model_ft, input_size

用于启动花卉分类任务的准备工作

def flower_start():# 模型初始化,获取设备,将模型放置到设备上进行训练model_ft, input_size = initialize_model("resnet", 102, feature_extract=True, use_pretrained=True)device = get_device()model_ft = model_ft.to(device)# 获取需要更新的参数,并提取params = model_ft.named_parameters()print('需要学习的参数有:')params_need_update = []for param_name, param in params:if param.requires_grad:params_need_update.append(param)print(param_name)# 数据路径data_dir = './flower_data'train_dir = data_dir+'/train'valid_dir = data_dir+'/valid'# 将分类后的编号和对应的名字找到with open('cat_to_name.json')as f:cat_to_name = json.load(f)# 数据增强变换"""torchvision.transforms 提供了一系列用于图像预处理的功能,包括图像增强、转换和标准化等操作。这些变换可以应用于图像数据,以增强模型的泛化能力或改善训练效果。transforms.Compose 是一个容器类,用于将多个变换组合在一起形成一个变换序列。这样可以方便地定义一系列变换操作,并按照顺序依次应用到图像数据上。"""data_transforms = {'train': transforms.Compose([transforms.RandomRotation(45),  # 随机旋转,-45到45度之间随机选transforms.CenterCrop(224),  # 从中心开始裁剪,只得到一张图片transforms.RandomHorizontalFlip(p=0.5),  # 随机水平翻转 概率为0.5transforms.RandomVerticalFlip(p=0.5),  # 随机垂直翻转transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),# 参数1为亮度,参数2为对比度,参数3为饱和度,参数4为色相transforms.RandomGrayscale(p=0.025),  # 概率转换成灰度率,3通道就是R=G=Btransforms.ToTensor(),# 迁移学习,用别人的均值和标准差transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # 均值,标准差]),'valid': transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),# 预处理必须和训练集一致transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),}batch_size = 8# 将变换后的图片用字典保存"""os.path.join(data_dir, x):将 data_dir(数据目录)与 'train' 或 'valid' 字符串拼接起来,形成训练集和验证集的完整路径。data_transforms[x]:根据 'train' 或 'valid' 选择相应的数据变换。datasets.ImageFolder:PyTorch 中的一个类,用于加载文件夹结构中的图像数据集。该类会自动根据文件夹结构生成类标签,并应用指定的变换。"""image_datasets = {x:datasets.ImageFolder(os.path.join(data_dir,x),data_transforms[x]) for x in ['train', 'valid']}# print(image_datasets)# 批量处理dataloaders = {x:torch.utils.data.DataLoader(image_datasets[x],batch_size=batch_size,shuffle=True) for x in ['train', 'valid']}dataset_sizes = {x:len(image_datasets[x]) for x in ['train','valid']}print(dataset_sizes)# 样本数据的标签class_names = image_datasets['train'].classesprint(class_names)# 画出预处理好的图像fig = plt.figure(figsize=(20,12))columns, rows = 4,2dataiter = iter(dataloaders['valid'])inputs, classes = next(dataiter)for idx in range(columns * rows):ax = fig.add_subplot(rows, columns, idx + 1, xticks=[], yticks=[])# classes为索引,class_name里为实际label,再去拿到对应的花名ax.set_title(cat_to_name[str(int(class_names[classes[idx]]))])img = transforms.ToPILImage()(inputs[idx])plt.imshow(img)plt.show()# 优化器设置optimizer_ft = optim.Adam(params_need_update,lr=1e-2)# 设置学习率衰减,每7个训练过程衰减为原来的1/10scheduler = optim.lr_scheduler.StepLR(optimizer_ft,step_size=7,gamma=1/10)# 这里不再使用交叉熵损失函数,因为模型中最后一层是logsoftmax(),已经是对数了,所有直接用nllloss输入已经经过 LogSoftmax() 处理的对数概率分布criterion = nn.NLLLoss()filename = 'wz.pth'model_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs = wz_model_train(model_ft,dataloaders,criterion,optimizer_ft,scheduler,filename,device)for param in model_ft.parameters():param.requires_grad = True

训练模型函数

def wz_model_train(model,dataloaders,criterion,optimizer,scheduler,filename:str,device:torch.device,num_epochs=2,is_inception=False):"""训练和验证模型:通过迭代数据集来进行训练和验证。保存最佳模型:记录并保存验证集上表现最好的模型。记录训练历史:记录每个epoch的训练和验证损失及准确率。学习率更新:使用学习率调度器更新学习率。性能报告:打印每个epoch的训练时间和性能指标。:param model:训练的模型:param dataloaders:数据加载器:param criterion:损失函数:param optimizer:优化器:param scheduler:学习率调度器:param filename:保存模型的文件名:param device:使用的设备:param num_epochs:训练轮数:param is_inception:是否使用inception网络:return:"""start_time = time.time() # 记录训练开始时间best_acc = 0  # 记录训练最好准确率best_model_weights = copy.deepcopy(model.state_dict())  # 记录最好训练的模型参数model.to(device)# 保存损失和准确率数据val_acc_history = []train_acc_history = []train_losses = []valid_losses = []# 记录每个epoch的学习率LRs = [optimizer.param_groups[0]['lr']]for epoch in range(num_epochs):print(f'Epoch {epoch}/{num_epochs-1}')print('----------------------------')# 训练和验证for phase in ['train', 'valid']:if phase=='train':model.train()else:model.eval()running_loss = 0.0 # 累计损失running_corrects = 0 # 累计正确预测的数量for inputs, labels in dataloaders[phase]:inputs = inputs.to(device)labels = labels.to(device)optimizer.zero_grad() # 梯度清零# 在训练模式下开启梯度计算,在评估模式下关闭梯度计算。with torch.set_grad_enabled(phase=='train'):# inception网络有一个辅助输出,和主输出加权取损失值,这样可以增加稳定性"""辅助输出是指在网络中间某一层产生的额外预测结果。这种设计主要用于提高模型的训练稳定性,并在一定程度上防止过拟合。辅助输出可以为网络提供额外的监督信号,帮助模型更快地收敛。在 Inception 网络中,辅助分类器(Auxiliary Classifier)通常位于网络的中间部分。这些辅助分类器可以提供额外的监督信号,帮助网络更好地学习特征。辅助分类器通常包括全局平均池化、全连接层和 Softmax 层,以便产生分类结果。"""if is_inception and phase=='train':outputs, aux_outputs = model(inputs)loss1 = criterion(outputs, labels)loss2 = criterion(aux_outputs, labels)loss = loss1+0.4*loss2else:  # 这里的训练不需要开启inception# print('没有开启inception')outputs = model(inputs)loss = criterion(outputs,labels)# 不要概率最大值本身,要的是他的标签_, preds = torch.max(outputs, 1)# 训练阶段更新权重if phase == 'train':loss.backward()optimizer.step()# 计算批量的loss和正确预测数量running_loss += loss.item() * inputs.size(0)running_corrects += torch.sum(preds==labels.data)# 计算平均损失和损失率epoch_loss = running_loss / len(dataloaders[phase].dataset)epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)# 打印一个epoch的时间,这个时间可以是训练阶段的,也可以是验证阶段的time_elapsed = time.time()-start_timeprint('本次epoch模型已经跑了{:.0f}分 {:.0f}秒'.format(time_elapsed//60,time_elapsed%60))print('{}的损失loss是:{:.4f},准确率是{:.4f}'.format(phase,epoch_loss,epoch_acc))# 得到最好的那次模型if phase=='valid' and epoch_acc>best_acc:best_acc = epoch_accbest_model_weights = copy.deepcopy(model.state_dict())state = {'state_dict':model.state_dict(),'best_acc':best_acc,'optimizer': optimizer.state_dict(),}torch.save(state,filename)if phase=='valid':val_acc_history.append(epoch_acc)valid_losses.append(epoch_loss)scheduler.step()  # 根据验证集来调整学习率if phase == 'train':train_acc_history.append(epoch_acc)train_losses.append(epoch_loss)print('Optimizer learning rate : {:.7f}'.format(optimizer.param_groups[0]['lr']))LRs.append(optimizer.param_groups[0]['lr'])time_elapsed = time.time()-start_timeprint('训练在{:.0f}分{:.0f}秒完成'.format(time_elapsed//60,time_elapsed%60))print('最好的精确值:{:4f}'.format(best_acc))# 将最好的训练一次当最终值model.load_state_dict(best_model_weights)return model, val_acc_history, train_acc_history,valid_losses,train_losses, LRs
ssl._create_default_https_context = ssl._create_unverified_context
flower_start()
gpu可用,训练中在gpu中进行
需要学习的参数有:
fc.0.weight
fc.0.bias
{'train': 6552, 'valid': 818}
['1', '10', '100', '101', '102', '11', '12', '13', '14', '15', '16', '17', '18', '19', '2', '20', '21', '22', '23', '24', '25', '26', '27', '28', '29', '3', '30', '31', '32', '33', '34', '35', '36', '37', '38', '39', '4', '40', '41', '42', '43', '44', '45', '46', '47', '48', '49', '5', '50', '51', '52', '53', '54', '55', '56', '57', '58', '59', '6', '60', '61', '62', '63', '64', '65', '66', '67', '68', '69', '7', '70', '71', '72', '73', '74', '75', '76', '77', '78', '79', '8', '80', '81', '82', '83', '84', '85', '86', '87', '88', '89', '9', '90', '91', '92', '93', '94', '95', '96', '97', '98', '99']

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

Epoch 0/1
----------------------------
本次epoch模型已经跑了3分 39秒
train的损失loss是:10.3695,准确率是0.3219
本次epoch模型已经跑了3分 60秒
valid的损失loss是:9.9343,准确率是0.4364
Optimizer learning rate : 0.0100000
Epoch 1/1
----------------------------
本次epoch模型已经跑了7分 46秒
train的损失loss是:8.1607,准确率是0.4899
本次epoch模型已经跑了8分 8秒
valid的损失loss是:15.6906,准确率是0.3619
Optimizer learning rate : 0.0100000
训练在8分8秒完成
最好的精确值:0.436430

.4364
Optimizer learning rate : 0.0100000
Epoch 1/1
----------------------------
本次epoch模型已经跑了7分 46秒
train的损失loss是:8.1607,准确率是0.4899
本次epoch模型已经跑了8分 8秒
valid的损失loss是:15.6906,准确率是0.3619
Optimizer learning rate : 0.0100000
训练在8分8秒完成
最好的精确值:0.436430



http://www.ppmy.cn/news/1528482.html

相关文章

24/9/16 算法笔记 评估模型

评估机器学习模型的性能是一个关键步骤,它可以帮助我们了解模型在实际应用中的表现。以下是一些常用的评估模型的方法: 准确率(Accuracy): 最常见的评估指标,表示正确预测的样本数占总样本数的比例。 精确度…

Nginx 常用功能

Nginx 四层访问控制 Nginx 中的访问控制功能基于 ngx_http_access_module 模块实现,可以通过匹配客户端源 IP 地址进行 限制 该模块是默认模块,在使用 apt/yum 安装的环境中默认存在,如果想要禁用,需要自行编译,然后…

Java 入门指南:JVM(Java虚拟机)垃圾回收机制 —— 垃圾回收算法

文章目录 垃圾回收机制垃圾判断算法引用计数法可达性分析算法虚拟机栈中的引用(方法的参数、局部变量等)本地方法栈中 JNI 的引用类静态变量运行时常量池中的常量 垃圾收集算法Mark-Sweep(标记-清除)算法Copying(标记-…

代码随想录算法训练营day37

1.携带研究材料 1.1 题目 52. 携带研究材料&#xff08;第七期模拟笔试&#xff09; 1.2 题解 #include <iostream> #include <functional> #include <vector> using namespace std;int main() {//输入相关信息int classes, cabaity;cin >> classe…

OpenHarmony(鸿蒙南向开发)——轻量系统STM32F407芯片移植案例

往期知识点记录&#xff1a; 鸿蒙&#xff08;HarmonyOS&#xff09;应用层开发&#xff08;北向&#xff09;知识点汇总 鸿蒙&#xff08;OpenHarmony&#xff09;南向开发保姆级知识点汇总~ OpenHarmony&#xff08;鸿蒙南向开发&#xff09;——轻量和小型系统三方库移植指南…

细胞分裂检测系统源码分享

细胞分裂检测检测系统源码分享 [一条龙教学YOLOV8标注好的数据集一键训练_70全套改进创新点发刊_Web前端展示] 1.研究背景与意义 项目参考AAAI Association for the Advancement of Artificial Intelligence 项目来源AACV Association for the Advancement of Computer Vis…

Imagen架构详解:理解其背后的技术与创新

Imagen架构详解&#xff1a;理解其背后的技术与创新 引言 近年来&#xff0c;生成式人工智能技术取得了飞速发展&#xff0c;特别是在图像生成领域。作为这一领域的重要创新之一&#xff0c;Imagen 是由谷歌开发的一种基于文本生成图像的模型。它在生成高质量、逼真的图像方面…

OpenCV calcHist()函数及其用法详解

OpenCV calcHist()函数原型共有三个&#xff0c;如下&#xff1a; 该函数计算一个或多个数组的直方图。用于递增直方图箱的元组的元素取自同一位置的相应输入数组。 函数参数&#xff1a; images 源&#xff08;图像&#xff09;数组。它们都应具有相同的深度、CV_8U、CV_16U…