基于深度学习的图像分类或识别系统(含全套项目+PyQt5界面)

news/2024/12/22 14:15:51/

目录

一、项目界面

二、代码实现

1、网络代码

2、训练代码

3、评估代码

4、结果显示

三、项目代码


一、项目界面

二、代码实现

1、网络代码

该网络基于残差模型修改

python">import torch
import torch.nn as nn
import torchvision.models as modelsclass resnet18(nn.Module):def __init__(self, num_classes=5, pretrained=False):super(resnet18, self).__init__()# 加载ResNet-18模型self.model = models.resnet18(pretrained=pretrained)# print(self.model)# 更改全连接层以输出自定义类别数量self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)def forward(self, x):return self.model(x)if __name__ == '__main__':# 示例用法num_classes = 10model = resnet18(num_classes=num_classes)# 打印模型以确认更改print(model)
2、训练代码
python">import os
import torch
import torch.nn as nn
from models.resnet18 import resnet18
from utils.utils import train_and_val,plot_acc,plot_loss,plot_lr,MyDataset
import numpy as np
from torch.utils.data import DataLoader
import glob
import pandas as pd
import configdef main(epochs,model):device = torch.device("cuda" if torch.cuda.is_available() else "cpu")if not os.path.exists(config.save_results):os.makedirs(config.save_results)# ----------------------------模型加载-------------------------model = model.to(device)loss_function = nn.CrossEntropyLoss()optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)scheduler = torch.optim.lr_scheduler.StepLR(optimizer,step_size=5,gamma=0.9)  # 每经过5个epoch,学习率乘以0.9# ------------------------------------------------------------# ---------------------------加载数据--------------------------im_train_list = glob.glob(config.train_path + "/*/*." + config.img_)im_val_list = glob.glob(config.val_path + "/*/*." + config.img_)train_dataset = MyDataset(im_train_list, config.label_names)val_dataset = MyDataset(im_val_list, config.label_names)train_loader = DataLoader(train_dataset,batch_size=config.batch_size,shuffle=True)val_loader = DataLoader(val_dataset,batch_size=config.batch_size,shuffle=False)print("num of train", len(train_dataset))print("num of val", len(val_loader))# ------------------------------------------------------------# ---------------------------网络训练--------------------------history = train_and_val(epochs, model, train_loader,val_loader,loss_function, optimizer,scheduler,config.save_results,device)df = pd.DataFrame(history) # 转换为DataFramedf.to_excel(os.path.join(config.save_results,'history.xlsx'), index=False) # 保存为 Excel 文件plot_loss(np.arange(0,epochs),config.save_results, history)plot_acc(np.arange(0,epochs),config.save_results, history)plot_lr(np.arange(0,epochs),config.save_results, history)if __name__ == '__main__':model = resnet18(num_classes=config.num_classes)main(config.epochs,model)
3、评估代码
python">from sklearn.metrics import classification_report
import torch
import os
import torch.nn as nn
from tqdm import tqdm
import pandas as pd
from models.resnet18 import resnet18
import matplotlib.pyplot as plt
from utils.utils import MyDataset,reports
from torch.utils.data import DataLoader
import seaborn as sns
import glob
import configdef main(model):device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# ----------------------------模型加载-------------------------model = model.to(device)checkpoint = torch.load(os.path.join(config.save_results,"best.pth"))model.load_state_dict(checkpoint, strict=True)model.eval()# ------------------------------------------------------------# ---------------------------加载数据--------------------------im_test_list = glob.glob(config.test_path + "/*/*." + config.img_)test_dataset = MyDataset(im_test_list, config.label_names)test_loader = DataLoader(test_dataset,batch_size=config.batch_size,shuffle=False)print("num of test", len(test_loader))# ------------------------------------------------------------act = nn.Softmax(dim=-1)y_true, y_pred = [], []with torch.no_grad():with tqdm(total=len(test_loader)) as pbar:for images, labels in test_loader:outputs = act(model(images.to(device)))_, predicted = torch.max(outputs, 1)predicted = predicted.cpu()y_pred.extend(predicted.numpy())y_true.extend(labels.cpu().numpy())pbar.update(1)oa,aa,kappa,cls,cm = reports(y_true, y_pred)cr = classification_report(y_true, y_pred, target_names=config.label_names.values(), output_dict=True)df = pd.DataFrame(cr).transpose()df.to_csv(os.path.join(config.save_results,"classification_report.csv"), index=True)print("Accuracy is :", oa)with open(os.path.join(config.save_results,"results.txt"), "a") as file:file.write('OA:{:.4f} AA:{:.4f} kappa:{:.4f}\ncls:{}\n混淆矩阵:\n{}\n'.format(oa, aa, kappa,cls,cm))plt.figure(figsize=(10, 7))sns.heatmap(cm, annot=True, xticklabels=config.label_names.values(), yticklabels=config.label_names.values(), cmap='Blues', fmt="d")plt.xlabel('Predicted')plt.ylabel('True')plt.savefig(os.path.join(config.save_results,'test_confusion_matrix.png'))plt.clf()if __name__ == '__main__':model = resnet18()main(model)
4、结果显示

上述仅仅是简单演示,结果没有参考意义。

三、项目代码

本项目的代码通过以下链接下载:基于深度学习的图像分类或识别系统(含全套项目+PyQt5界面)


http://www.ppmy.cn/news/1526858.html

相关文章

C++ | Leetcode C++题解之第409题最长回文串

题目&#xff1a; 题解&#xff1a; class Solution { public:int longestPalindrome(string s) {unordered_map<char, int> count;int ans 0;for (char c : s)count[c];for (auto p : count) {int v p.second;ans v / 2 * 2;if (v % 2 1 and ans % 2 0)ans;}retur…

深度学习速通系列:依存分析

依存分析&#xff08;Dependency Parsing&#xff09;是自然语言处理&#xff08;NLP&#xff09;中的一项任务&#xff0c;目的是确定句子中单词之间的依存关系&#xff0c;并将这些关系表示为一个有向图&#xff0c;通常称为依存树。在依存树中&#xff0c;每个节点代表一个单…

电脑安装OpenWRT系统

通过网盘分享的文件&#xff1a;OpenWRT 链接: https://pan.baidu.com/s/1nrRBeKgGviD31Omji480qA?pwd9900 提取码: 9900 下面开始教程&#xff1a; 1.先把普通U盘制作成一个PE启动盘&#xff0c;我用的是微PE工具箱&#xff0c;直接安装PE到U盘。 2.把写盘工具和openWRT系统…

高级java每日一道面试题-2024年9月13日-基础篇-如何测试事务的正确性?

如果有遗漏,评论区告诉我进行补充 面试官: 如何测试事务的正确性&#xff1f; 我回答: 在Java高级面试中&#xff0c;测试事务的正确性是一个重要的话题&#xff0c;因为事务管理对于确保数据的一致性和完整性至关重要。事务的正确性测试通常涉及多个方面&#xff0c;包括原…

linux-系统备份与恢复-系统恢复

Linux 系统备份与恢复&#xff1a;系统恢复 1. 概述 Linux 系统的恢复是系统管理的重要组成部分&#xff0c;它指的是在系统崩溃、硬件故障、误操作或安全问题后&#xff0c;恢复系统到可用状态的过程。良好的系统恢复计划可以有效避免数据丢失和业务中断&#xff0c;并确保系…

初中生物--4.生物体的结构层次(二)

一、植物体的结构层次 1.绿色开花植物的六大器官 根、茎、叶、花、种子、果实 2.植物的组织 3.植物体的生长 植物体的生长是细胞分裂、生长和分化的综合结果。在植物体的生长过程中&#xff0c;细胞不断分裂产生新的细胞&#xff0c;新细胞不断生长使细胞体积增大&#xff…

Ubuntu搭建FTP服务器

1. 首先&#xff0c;我们需要安装和配置xinetd&#xff0c;安装的具体命令如下&#xff1a; sudo apt-get install xinetd 2. 新建tftp工作目录&#xff0c;并添加读、写、执行权限&#xff08;没有权限后面无法正常访问该文件夹&#xff09;&#xff0c;如下图所示。 3. 安装…

【PCB工艺】表面贴装技术中常见错误

系列文章目录 1.元件基础 2.电路设计 3.PCB设计 4.元件焊接 5.板子调试 6.程序设计 7.算法学习 8.编写exe 9.检测标准 10.项目举例 11.职业规划 文章目录 1、什么是SMT和SMD2、表面贴装技术的优势是什么&#xff1f;3、通孔和表面贴装技术之间的区别是什么&#xff1f;4、焊…