pytorch利用简单CNN实现葡萄病虫害图片识别

embedded/2024/10/18 12:28:28/

1 前言

之前我开发了一个葡萄病虫害的可视化系统,最近就想给这个系统增加2个功能,一个是对接一个AI助手,可以进行葡萄病虫害的咨询,直接对接千问大模型,这个在之前的博文里已经介绍过对接方法了,第二个是做一个根据图片识别病虫害(分类)的功能。

2 实现思路

实现思路是想通过pytorch做一个CNN模型的训练,然后根据给出的图片进行类型的预测。

3 数据集

我没有数据集,仅有的一些图片是之前委托我做程序的bro给的,所以我们训练的时候图片并不多,不过这个没关系,数据集可以后期扩充,目前先实现功能部分

4 安装依赖

该功能由python语言实现,使用pip 安装如下依赖

torch
torchvision
matplotlib

5 数据位置

在这里插入图片描述
数据类似这样去组织,一种类型建一个文件夹,然后同一类型的图片放一起。

6 训练模型

import os
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F# 数据预处理
transform = transforms.Compose([transforms.Resize((128, 128)),  # 调整图片大小transforms.ToTensor(),            # 转换为 Tensor
])# 加载数据集
data_dir = 'dataset'
dataset = datasets.ImageFolder(root=data_dir, transform=transform)
data_loader = DataLoader(dataset, batch_size=8, shuffle=True)# 获取类别标签
class_names = dataset.classes
num_classes = len(class_names)# 构建简单的 CNN 模型
class SimpleCNN(nn.Module):def __init__(self, num_classes):super(SimpleCNN, self).__init__()self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)self.fc1 = nn.Linear(32 * 32 * 32, 128)  # 128 = (128/2)*(128/2)*(32/2)*(32/2)self.fc2 = nn.Linear(128, num_classes)def forward(self, x):x = self.pool(F.relu(self.conv1(x)))x = self.pool(F.relu(self.conv2(x)))x = x.view(-1, 32 * 32 * 32)x = F.relu(self.fc1(x))x = self.fc2(x)return x# 实例化模型
model = SimpleCNN(num_classes)# 训练配置
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
print(torch.cuda.is_available())
print(f'Using device: {device}')criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)# 训练循环
num_epochs = 10for epoch in range(num_epochs):running_loss = 0.0for images, labels in data_loader:images, labels = images.to(device), labels.to(device)# 前向传播outputs = model(images)loss = criterion(outputs, labels)# 反向传播和优化optimizer.zero_grad()loss.backward()optimizer.step()running_loss += loss.item()print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {running_loss / len(data_loader):.4f}")print("Training finished.")# 保存模型
torch.save(model.state_dict(), 'plant_disease_model.pth')

执行代码之后得到模型文件:
在这里插入图片描述

7 预测模型

然后我们随便去找些病虫害图片,来做预测

import torch
from torchvision import transforms
from PIL import Image
import os
import torch.nn as nn
import torch.nn.functional as F# 定义简单的 CNN 模型结构
class SimpleCNN(nn.Module):def __init__(self, num_classes):super(SimpleCNN, self).__init__()self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)self.fc1 = nn.Linear(32 * 32 * 32, 128)self.fc2 = nn.Linear(128, num_classes)def forward(self, x):x = self.pool(F.relu(self.conv1(x)))x = self.pool(F.relu(self.conv2(x)))x = x.view(-1, 32 * 32 * 32)x = F.relu(self.fc1(x))x = self.fc2(x)return x# 预测函数
def predict(image_path, model, class_names):# 定义图像预处理transform = transforms.Compose([transforms.Resize((128, 128)),  # 统一大小transforms.ToTensor(),])# 加载和预处理图像image = Image.open(image_path)image = transform(image).unsqueeze(0)  # 增加批次维度# 将图像输入模型进行预测model.eval()  # 设置模型为评估模式with torch.no_grad():outputs = model(image)_, predicted = torch.max(outputs, 1)# 返回预测的类别return class_names[predicted.item()]if __name__ == "__main__":# 加载训练好的模型num_classes = 2  # 根据你的数据集类别数量修改model = SimpleCNN(num_classes)model.load_state_dict(torch.load('plant_disease_model.pth'))model.eval()# 类别名称(根据你的数据集修改)class_names = ['disease1', 'disease2']  # 替换为实际类别名称# 测试预测test_image_path = '1.jpg'  # 替换为测试图像的路径predicted_class = predict(test_image_path, model, class_names)print(f'Predicted class: {predicted_class}')

8 结果

给出的图片和图片预测结果如下:
在这里插入图片描述


http://www.ppmy.cn/embedded/105528.html

相关文章

红帽与SUSE对RHEL/CentOS 7系列延长生命周期支持策略:保障企业Linux系统的持续安全与稳定

一、前言 昨天有幸参加了一个活动,其一主办方是SUSE,感谢SUSE的工程师提供相关信息。 在本篇文章中,我们将深入探讨两个关键的Linux操作系统支持方案:“红帽企业版 Linux 7(RHEL 7)延长生命周期支持”和“…

实现多云对象存储支持:Go 语言实践

实现多云对象存储支持:Go 语言实践 在现代云原生应用开发中,对象存储已成为不可或缺的组件。然而,不同的云服务提供商有各自的对象存储服务和 SDK。本文将介绍如何在 Go 语言中实现一个灵活的对象存储系统,支持多个主流云服务提供…

C++设计模式——Observer观察者模式

一,观察者模式的定义 观察者模式是一种行为型设计模式,又被称为"发布-订阅"模式,它定义了对象之间的一对多的依赖关系,当一个对象的状态发生变化时,所有依赖于它的对象都会收到通知并自动更新。 观察者模式…

linux nc

/* * nc */ 远程文件传输 目的主机监听 nc -l 监听端口[ 未使用端口] > 要接收的文件名 nc -l 8888 > ac.c 源主机发起请求 nc 目的主机ip 目的端口 < 要发送的文件 nc 192.168.11.21 8888 < /home/share/ac.c /* * 使用 * 收方先…

【机器学习】表示学习的基本概念和方法以及编解码结构的基本概念

引言 表示学习&#xff08;Representation Learning&#xff09;是机器学习的一个子领域&#xff0c;它专注于学习数据的表示形式&#xff0c;即数据的高层特征或抽象概念 文章目录 引言一、表示学习1.1 表示学习的重要性1.2 表示学习的方法1.3 应用场景1.4 挑战1.5 总结 二、如…

Java快速入门 知识精简(6)异常处理

异常处理 异常&#xff1a;指的是程序在执行过程中。出现的非正常的情况&#xff0c;如果不处理最终会导致JVM的非正常停止。 为保证程序正常执行&#xff0c;代码必须对可能出现的异常进行处理 说明&#xff1a; 1&#xff09;异常指的并不是语法错误&#xff1b;语法错了&…

【2024-2025源码+文档+调试讲解】微信小程序的城市公交查询系统

摘 要 当今社会已经步入了科学技术进步和经济社会快速发展的新时期&#xff0c;国际信息和学术交流也不断加强&#xff0c;计算机技术对经济社会发展和人民生活改善的影响也日益突出&#xff0c;人类的生存和思考方式也产生了变化。传统城市公交查询管理采取了人工的管理方法…

解决AutoDL远程服务器训练大模型的常见问题:CPU内存不足与 SSH 断开

在使用远程服务器&#xff08;如 AutoDL&#xff09;进行深度学习训练时&#xff0c;通常会遇到一些常见问题&#xff0c;比如由于数据加载导致的内存消耗过高&#xff0c;以及 SSH 连接中断后训练任务被迫停止。这篇文章将介绍我在这些问题上遇到的挑战&#xff0c;并分享相应…