基于Pytorch框架的深度学习EfficientNetV2神经网络中草药识别分类系统源码

ops/2025/2/16 1:50:27/

 第一步:准备数据

5种中草药数据:self.class_indict = ["百合", "党参", "山魈", "枸杞", "槐花", "金银花"]

,总共有900张图片,每个文件夹单独放一种数据

第二步:搭建模型

本文选择一个EfficientNetV2网络,其原理介绍如下:

        该网络主要使用训练感知神经结构搜索缩放的组合;在EfficientNetV1的基础上,引入了Fused-MBConv到搜索空间中;引入渐进式学习策略自适应正则强度调整机制使得训练更快;进一步关注模型的推理速度训练速度

与EfficientV1相比,主要有以下不同:

  1. V2中除了使用MBConv模块外,还使用了Fused-MBConv模块
  2. V2中会使用较小的expansion ratio,在V1中基本都是6。这样的好处是能够减少内存访问开销
  3. V2中更偏向使用更小的kernel_size(3 x 3),在V1中很多5 x 5。优于3 x 3的感受野是比5 x 5小的,所以需要堆叠更多的层结构以增加感受野
  4. 移除了V1中最优一个步距为1的stage

第三步:训练代码

1)损失函数为:交叉熵损失函数

2)训练代码:

import os
import math
import argparseimport torch
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
import torch.optim.lr_scheduler as lr_schedulerfrom model import efficientnetv2_s as create_model
from my_dataset import MyDataSet
from utils import read_split_data, train_one_epoch, evaluatedef main(args):device = torch.device(args.device if torch.cuda.is_available() else "cpu")print(args)print('Start Tensorboard with "tensorboard --logdir=runs", view at http://localhost:6006/')tb_writer = SummaryWriter()if os.path.exists("./weights") is False:os.makedirs("./weights")train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(args.data_path)img_size = {"s": [300, 384],  # train_size, val_size"m": [384, 480],"l": [384, 480]}num_model = "s"data_transform = {"train": transforms.Compose([transforms.RandomResizedCrop(img_size[num_model][0]),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),"val": transforms.Compose([transforms.Resize(img_size[num_model][1]),transforms.CenterCrop(img_size[num_model][1]),transforms.ToTensor(),transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])}# 实例化训练数据集train_dataset = MyDataSet(images_path=train_images_path,images_class=train_images_label,transform=data_transform["train"])# 实例化验证数据集val_dataset = MyDataSet(images_path=val_images_path,images_class=val_images_label,transform=data_transform["val"])batch_size = args.batch_sizenw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workersprint('Using {} dataloader workers every process'.format(nw))train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size,shuffle=True,pin_memory=True,num_workers=nw,collate_fn=train_dataset.collate_fn)val_loader = torch.utils.data.DataLoader(val_dataset,batch_size=batch_size,shuffle=False,pin_memory=True,num_workers=nw,collate_fn=val_dataset.collate_fn)# 如果存在预训练权重则载入model = create_model(num_classes=args.num_classes).to(device)if args.weights != "":if os.path.exists(args.weights):weights_dict = torch.load(args.weights, map_location=device)load_weights_dict = {k: v for k, v in weights_dict.items()if model.state_dict()[k].numel() == v.numel()}print(model.load_state_dict(load_weights_dict, strict=False))else:raise FileNotFoundError("not found weights file: {}".format(args.weights))# 是否冻结权重if args.freeze_layers:for name, para in model.named_parameters():# 除head外,其他权重全部冻结if "head" not in name:para.requires_grad_(False)else:print("training {}".format(name))pg = [p for p in model.parameters() if p.requires_grad]optimizer = optim.SGD(pg, lr=args.lr, momentum=0.9, weight_decay=1E-4)# Scheduler https://arxiv.org/pdf/1812.01187.pdflf = lambda x: ((1 + math.cos(x * math.pi / args.epochs)) / 2) * (1 - args.lrf) + args.lrf  # cosinescheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)for epoch in range(args.epochs):# traintrain_loss, train_acc = train_one_epoch(model=model,optimizer=optimizer,data_loader=train_loader,device=device,epoch=epoch)scheduler.step()# validateval_loss, val_acc = evaluate(model=model,data_loader=val_loader,device=device,epoch=epoch)tags = ["train_loss", "train_acc", "val_loss", "val_acc", "learning_rate"]tb_writer.add_scalar(tags[0], train_loss, epoch)tb_writer.add_scalar(tags[1], train_acc, epoch)tb_writer.add_scalar(tags[2], val_loss, epoch)tb_writer.add_scalar(tags[3], val_acc, epoch)tb_writer.add_scalar(tags[4], optimizer.param_groups[0]["lr"], epoch)torch.save(model.state_dict(), "./weights/model-{}.pth".format(epoch))if __name__ == '__main__':parser = argparse.ArgumentParser()parser.add_argument('--num_classes', type=int, default=5)parser.add_argument('--epochs', type=int, default=100)parser.add_argument('--batch-size', type=int, default=4)parser.add_argument('--lr', type=float, default=0.01)parser.add_argument('--lrf', type=float, default=0.01)# 数据集所在根目录# https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgzparser.add_argument('--data-path', type=str,default=r"G:\demo\data\ChineseMedicine")# download model weights# 链接: https://pan.baidu.com/s/1uZX36rvrfEss-JGj4yfzbQ  密码: 5gu1parser.add_argument('--weights', type=str, default='./pre_efficientnetv2-s.pth',help='initial weights path')parser.add_argument('--freeze-layers', type=bool, default=True)parser.add_argument('--device', default='cuda:0', help='device id (i.e. 0 or 0,1 or cpu)')opt = parser.parse_args()main(opt)

第四步:统计正确率

第五步:搭建GUI界面

第六步:整个工程的内容

有训练代码和训练好的模型以及训练过程,提供数据,提供GUI界面代码

代码的下载路径(新窗口打开链接):基于Pytorch框架的深度学习EfficientNetV2神经网络中草药识别分类系统源码

有问题可以私信或者留言,有问必答


http://www.ppmy.cn/ops/45555.html

相关文章

Mac连接虚拟机(Linux系统)

1.确定虚拟机的IP地址 ifconfig //终端命令,查询ip地址 sudo apt install net-tools 安装完成后再次执行 ifconfig: 2.安装SSH(加密远程登录协议) (1).安装OpenSSH服务器软件包: sudo apt-get install openssh-ser…

软件架构设计属性之四:性能属性浅析

文章目录 引言一、性能属性的内涵与重要性1.1 性能属性的内涵1.2 性能属性的重要性1.3 性能属性的实践应用 二、性能属性的关键影响因素2.1 系统结构2.2 数据处理2.3 并发处理2.4 硬件和资源2.5 软件和框架2.6 监控和优化 三、性能属性的优化策略与实践3.1 优先级队列与资源调度…

AI+算力:科技新时代的创新引擎

随着人工智能(AI)技术的飞速发展,“AI算力”的结合应用已成为科技行业的热点话题,甚至诞生出“AI算力最强龙头“的网络热门等式。这个组合不仅可以提高计算效率,还可以为各行各业带来更强大的数据处理和分析能力&#…

四川汇聚荣聚荣科技有限公司在市场评价好吗?

随着科技行业的迅猛发展,越来越多的科技公司如雨后春笋般涌现,其中不乏一些优秀的企业。四川汇聚荣聚荣科技有限公司便是其中的一员。那么,这家公司在市场上的评价如何呢?接下来,我们将从四个方面进行详细的阐述。 一、公司概况四…

代码随想录算法训练营day34 | 455.分发饼干、376. 摆动序列、53. 最大子序和

理论基础 贪心的本质是选择每一阶段的局部最优,从而达到全局最优。 刷题或者面试的时候,手动模拟一下感觉可以局部最优推出整体最优,而且想不到反例,那么就试一试贪心。 455.分发饼干 result和j变化一致,可以去除一…

JS的基本内容

JS中的六中数据类型字符型,数值型,布尔型,Null,undefined和对象Object:符合数据类型,对象是属性和方法的集合甚至是另一种类型的对象。 基本数据类型:数值、字符串、null、undefined、布尔&…

代码随想录算法训练营第四十四天|km46. 携带研究材料、 416. 分割等和子集

代码随想录算法训练营第四十四天 km46. 携带研究材料 题目链接:km46. 携带研究材料 确定dp数组以及下标的含义:i的含义是物品编号从0到i,j的含义是当前背包的最大容量,dp[i][j]背包内物品的总价值确定递推公式:背包…

动手学深度学习(Pytorch版)代码实践-深度学习基础-01基础函数的使用

01基础函数的使用 主要内容 张量操作:创建和操作张量,包括重塑、填充、逐元素操作等。数据处理:使用pandas加载和处理数据,包括处理缺失值和进行one-hot编码。线性代数:包括矩阵运算、求和、均值、点积和各种范数计算…