PyTorch实现的猫狗图像分类项

embedded/2024/12/22 23:11:56/

猫狗图像分类项目

这是一个使用PyTorch实现的猫狗图像分类项目。

项目结构

  • model.py: 定义了CNN模型结构
  • train.py: 训练模型的脚本
  • predict.py: 使用训练好的模型进行预测
  • requirements.txt: 项目依赖

环境配置

  1. 创建虚拟环境(推荐)
python -m venv venv
source venv/bin/activate  # Linux/Mac
venv\Scripts\activate     # Windows
  1. 安装依赖
#requirements.txt
torch>=2.0.0
torchvision>=0.15.0
pillow>=9.0.0
numpy>=1.21.0
tqdm>=4.65.0···```bash
pip install -r requirements.txt

数据集准备

准备数据集目录结构如下:

data/
├── train/
│   ├── cat/
│   │   ├── cat1.jpg
│   │   ├── cat2.jpg
│   │   └── ...
│   └── dog/
│       ├── dog1.jpg
│       ├── dog2.jpg
│       └── ...

训练模型

python train.py

预测

修改 predict.py 中的图片路径和模型路径,然后运行:

python predict.py

模型说明

  • 使用了简单的CNN架构
  • 输入图像大小:224x224
  • 输出类别:猫(0)和狗(1)

下载图片

python download_dataset.py

import os
import urllib.request
import zipfile
from tqdm import tqdmdef download_file(url, filename):"""下载文件并显示进度条"""class DownloadProgressBar(tqdm):def update_to(self, b=1, bsize=1, tsize=None):if tsize is not None:self.total = tsizeself.update(b * bsize - self.n)with DownloadProgressBar(unit='B', unit_scale=True,miniters=1, desc=filename) as t:urllib.request.urlretrieve(url, filename=filename,reporthook=t.update_to)def prepare_dataset():# 创建数据目录os.makedirs('data/train/cat', exist_ok=True)os.makedirs('data/train/dog', exist_ok=True)# 下载示例数据集print("正在下载示例数据集...")dataset_url = "https://download.microsoft.com/download/3/E/1/3E1C3F21-ECDB-4869-8368-6DEBA77B919F/kagglecatsanddogs_5340.zip"zip_path = "dataset.zip"try:download_file(dataset_url, zip_path)print("正在解压数据集...")with zipfile.ZipFile(zip_path, 'r') as zip_ref:zip_ref.extractall("temp_data")# 移动文件到对应目录import shutilcat_source = "temp_data/PetImages/Cat"dog_source = "temp_data/PetImages/Dog"print("正在整理数据...")# 移动一部分猫的图片for i, filename in enumerate(os.listdir(cat_source)):if i >= 1000:  # 只使用1000张图片breaksrc = os.path.join(cat_source, filename)dst = os.path.join('data/train/cat', filename)try:if os.path.getsize(src) > 0:  # 检查文件是否有效shutil.copy2(src, dst)except:continue# 移动一部分狗的图片for i, filename in enumerate(os.listdir(dog_source)):if i >= 1000:  # 只使用1000张图片breaksrc = os.path.join(dog_source, filename)dst = os.path.join('data/train/dog', filename)try:if os.path.getsize(src) > 0:  # 检查文件是否有效shutil.copy2(src, dst)except:continue# 清理临时文件print("清理临时文件...")os.remove(zip_path)shutil.rmtree("temp_data")print("数据集准备完成!")print(f"猫图片数量: {len(os.listdir('data/train/cat'))}")print(f"狗图片数量: {len(os.listdir('data/train/dog'))}")except Exception as e:print(f"下载或处理数据时出错: {str(e)}")if __name__ == "__main__":prepare_dataset()

各取1000张进行训练:

import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
from tqdm import tqdm
from model import CatDogNet# 设置设备 - 如果有GPU就用GPU,没有就用CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')# 数据预处理转换
# 1. 调整图片大小为224x224
# 2. 转换为tensor格式
# 3. 标准化处理(使用ImageNet的均值和标准差)
transform = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])def train_model(data_dir, num_epochs=10, batch_size=32, learning_rate=0.001):"""训练模型的主函数参数:data_dir (str): 数据集目录路径num_epochs (int): 训练轮数,默认10轮batch_size (int): 批次大小,默认32learning_rate (float): 学习率,默认0.001"""# 加载训练数据集# ImageFolder会自动根据子文件夹名称作为类别标签train_dataset = datasets.ImageFolder(os.path.join(data_dir, 'train'),transform=transform)# 创建数据加载器# shuffle=True 确保每个epoch数据顺序随机# num_workers=4 使用4个进程加载数据train_loader = DataLoader(train_dataset,batch_size=batch_size,shuffle=True,num_workers=4)# 初始化模型、损失函数和优化器model = CatDogNet().to(device)  # 将模型移到GPU(如果可用)criterion = nn.CrossEntropyLoss()  # 交叉熵损失函数optimizer = optim.Adam(model.parameters(), lr=learning_rate)  # Adam优化器# 训练循环for epoch in range(num_epochs):model.train()  # 设置为训练模式running_loss = 0.0  # 记录总损失correct = 0  # 记录正确预测数total = 0  # 记录总样本数# 创建进度条progress_bar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}')for inputs, labels in progress_bar:# 将数据移到GPU(如果可用)inputs, labels = inputs.to(device), labels.to(device)# 清零梯度optimizer.zero_grad()# 前向传播outputs = model(inputs)# 计算损失loss = criterion(outputs, labels)# 反向传播loss.backward()# 更新参数optimizer.step()# 统计训练信息running_loss += loss.item()_, predicted = outputs.max(1)total += labels.size(0)correct += predicted.eq(labels).sum().item()# 更新进度条信息progress_bar.set_postfix({'loss': f'{running_loss/len(progress_bar):.3f}',  # 平均损失'acc': f'{100.*correct/total:.2f}%'  # 准确率})# 每5个epoch保存一次模型if (epoch + 1) % 5 == 0:torch.save(model.state_dict(), f'model_epoch_{epoch+1}.pth')print('训练完成!')return modelif __name__ == '__main__':# 设置数据集路径并开始训练data_dir = './data'  # 数据集路径model = train_model(data_dir)  # 开始训练模型

预测

import torch
from PIL import Image
from torchvision import transforms
from model import CatDogNetdef predict_image(image_path, model_path):# 设置设备device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')# 加载模型model = CatDogNet().to(device)model.load_state_dict(torch.load(model_path))model.eval()# 图像预处理transform = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])# 加载并处理图像image = Image.open(image_path).convert('RGB')image = transform(image).unsqueeze(0).to(device)# 预测with torch.no_grad():outputs = model(image)_, predicted = torch.max(outputs, 1)# 返回预测结果class_names = ['cat', 'dog']return class_names[predicted.item()]if __name__ == '__main__':# 使用示例image_path = 'path_to_your_image.jpg'  # 修改为你的图片路径model_path = 'model_epoch_10.pth'      # 修改为你的模型路径result = predict_image(image_path, model_path)print(f'预测结果: {result}')
result, confidence = predict_image('path_to_your_image.jpg') #单张
predict_directory('path_to_your_directory') #批量

http://www.ppmy.cn/embedded/147926.html

相关文章

C# 6.0 连接elasticsearch数据库

在 C# 6.0 中连接 Elasticsearch 数据库,您可以使用官方的 Elasticsearch 客户端库 NEST。NEST 是一个高性能的 .NET 客户端,用于与 Elasticsearch 进行交互。以下是一个详细的步骤指南,帮助您在 C# 6.0 项目中连接和操作 Elasticsearch。 1. 安装 NEST 包 首先,您需要在您…

Fgui世界坐标转ui坐标的问题

在做玩家与3d物体交互的时候遇到一个问题,就是3d物体的世界坐标转换成Fgui的UI坐标,会有一点问题,在fgui的官方文档中是这么描述一个3d物体的世界坐标转换为fgui的ui坐标是这么描述的 这个应该是一个比较普遍的方案,在我的实际项目…

【线性代数】理解矩阵乘法的意义(点乘)

刚接触线性代数时,很不理解矩阵乘法的计算规则,为什么规则定义的看起来那么有规律却又莫名其妙,现在参考了一些资料,回过头重新总结下个人对矩阵乘法的理解(严格来说是点乘)。 理解矩阵和矩阵的乘法&#x…

siglip代码笔记

Github siglip-so400m-patch14-384 使用了SoViT-400m结构,SoViT :a shape-optimized vision transformer,结构参数经过试验测试得到。具体见 Getting ViT in Shape: Scaling Laws for Compute-Optimal Model Design We validate these predic…

MFC/C++学习系列之简单记录11——树控件的使用

MFC/C学习系列之简单记录11——树控件的使用 前言CTreectrl使用界面设置代码使用简单设计其他使用注意! 总结 前言 在之前的界面设计中使用得很少,但是可以学习一下,以备不时之需! CTreectrl使用 界面设置 在工具箱中选择Tree C…

数据结构—图

目录 一、图的定义 二、图的基本概念和术语 2.1有向图 2.2无向图 2.3简单图 2.4多重图 2.5完全图 2.6子图 2.7连通、连通图和连通分量 2.8强连通图、强联通分量 2.9生成树,生成森林 2.10顶点的度、入度和出度 2.11边的权和网 2.12稠密图、稀疏图 2.1…

【docker】容器编排之docker swarm

Docker Swarm容器编排详细讲解 Docker Swarm是Docker的原生容器编排工具,它通过将多个Docker引擎组合成一个集群来实现高效的容器部署和管理。 Swarm提供了服务发现、负载均衡、扩展、自动恢复等功能,能够让开发者和运维人员以更简便的方式管理容器化应…

VMWare 的克隆操作

零、碎碎念 VMWare 的这个克隆操作很简单,单拎出来成贴的目的是方便后续使用。 一、操作步骤 1.1、在“源”服务器上点右键,选择“管理--克隆” 1.2、选择“虚拟机的当前状态”为基础制作克隆,如下图所示,然后点击“下一页” 1.3、…