pytorch cnn 实现猫狗分类

ops/2025/2/22 4:15:23/

文章目录

    • @[toc]
  • 1. 导入必要的库
  • 2. 定义数据集类
  • 3. 数据预处理和加载
  • 4. 定义 CNN 模型
  • 5. 定义损失函数和优化器
  • 6. 训练模型
  • 7. 保存模型
  • 8. 使用模型进行预测
  • 9 完整代码
  • 10. 总结

1. 导入必要的库

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
import os

2. 定义数据集类

我们将创建一个自定义数据集类来加载猫狗图片。数据集根目录下需要有 cat 和 dog 两个子目录（第 9 节的完整代码中为 cats 和 dogs），子目录中的图片分别被标记为类别 0（猫）和 1（狗）。


class CatDogDataset(Dataset):
    """Binary cat/dog image dataset backed by one sub-directory per class.

    Every accepted file in ``root_dir/<classes[i]>`` is loaded with label
    ``i``; with the default ``classes`` a cat image is 0 and a dog image is 1.

    Args:
        root_dir: Directory that contains the class sub-directories.
        transform: Optional callable applied to each RGB PIL image in
            ``__getitem__`` (e.g. a torchvision ``transforms.Compose``).
        classes: Class sub-directory names, in label order.
        extensions: Accepted file suffixes (case-insensitive). Non-image files
            that often sit next to the pictures (e.g. ``Thumbs.db``) would
            otherwise make ``Image.open`` fail mid-epoch. ``None`` accepts
            every regular file.
    """

    IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.gif', '.webp', '.tif', '.tiff')

    def __init__(self, root_dir, transform=None, classes=('cat', 'dog'),
                 extensions=IMG_EXTENSIONS):
        self.root_dir = root_dir
        self.transform = transform
        self.classes = list(classes)
        self.image_paths = []
        self.labels = []
        suffixes = None if extensions is None else tuple(e.lower() for e in extensions)
        # Walk each class directory; the class's index is used as its label.
        for idx, class_name in enumerate(self.classes):
            class_dir = os.path.join(root_dir, class_name)
            # os.listdir returns entries in arbitrary order; sort so dataset
            # indices are reproducible across runs and machines.
            for img_name in sorted(os.listdir(class_dir)):
                img_path = os.path.join(class_dir, img_name)
                if not os.path.isfile(img_path):
                    continue  # skip nested directories
                if suffixes is not None and not img_name.lower().endswith(suffixes):
                    continue  # skip non-image files
                self.image_paths.append(img_path)
                self.labels.append(idx)

    def __len__(self):
        """Return the number of images found."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        The image is converted to RGB so every sample has 3 channels, then
        passed through ``transform`` if one was given.
        """
        # The context manager closes the file handle; convert() returns a new
        # in-memory image, so it stays valid after the file is closed.
        with Image.open(self.image_paths[idx]) as img:
            image = img.convert('RGB')
        label = self.labels[idx]
        if self.transform:
            image = self.transform(image)
        return image, label

3. 数据预处理和加载

定义数据预处理方法（将图片缩放到 64×64、转换为张量，并用 ImageNet 的均值和标准差进行标准化），然后加载数据集。

# Data preprocessing: resize to the 64x64 input size the CNN's fc1 layer is
# sized for, convert to a tensor, then normalize with ImageNet channel stats.
transform = transforms.Compose([
    transforms.Resize((64, 64)),  # resize image
    transforms.ToTensor(),  # PIL image -> float tensor in [0, 1]
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # normalize
])

# Load the datasets (replace the placeholder paths with real directories).
train_dataset = CatDogDataset(root_dir='path_to_train_data', transform=transform)
val_dataset = CatDogDataset(root_dir='path_to_val_data', transform=transform)

# Data loaders: only the training set is shuffled.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

4. 定义 CNN 模型


class CNN(nn.Module):
    """Small two-conv-layer CNN producing 2 class logits (cat / dog).

    Expects input of shape (N, 3, 64, 64). Each 3x3 conv (padding 1) keeps
    the spatial size and each 2x2 max-pool halves it, so the feature map that
    reaches ``fc1`` is 64 x 16 x 16. Returns raw logits of shape (N, 2).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.fc1 = nn.Linear(64 * 16 * 16, 512)
        self.fc2 = nn.Linear(512, 2)
        self.dropout = nn.Dropout(0.5)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.pool(self.relu(self.conv1(x)))  # (N, 32, 32, 32)
        x = self.pool(self.relu(self.conv2(x)))  # (N, 64, 16, 16)
        # Flatten per sample. Unlike view(-1, 64 * 16 * 16), this keeps the
        # batch dimension, so a wrongly sized input raises in fc1 instead of
        # being silently reshaped into a different batch size.
        x = torch.flatten(x, 1)
        x = self.dropout(self.relu(self.fc1(x)))
        return self.fc2(x)


model = CNN()

5. 定义损失函数和优化器

# Loss: cross-entropy applied to the model's raw 2-class logits.
criterion = nn.CrossEntropyLoss()
# Optimizer: Adam over every trainable parameter, learning rate 1e-3.
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

6. 训练模型

num_epochs = 10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

for epoch in range(num_epochs):
    # Training pass.
    model.train()  # enable dropout
    running_loss = 0.0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    # Mean of per-batch losses; max(..., 1) guards against an empty loader.
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss / max(len(train_loader), 1):.4f}")

    # Validation pass.
    model.eval()  # disable dropout
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            predicted = model(images).argmax(dim=1)  # index of the larger logit
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    if total:
        print(f"Validation Accuracy: {100 * correct / total:.2f}%")
    else:
        print("Validation Accuracy: n/a (empty validation set)")

7. 保存模型

# Persist only the learned parameters (the state_dict), not the module object.
torch.save(obj=model.state_dict(), f='cat_dog_classifier.pth')

8. 使用模型进行预测

# Load the trained weights. map_location lets a checkpoint saved on a GPU be
# loaded on a CPU-only machine (and vice versa).
model.load_state_dict(torch.load('cat_dog_classifier.pth', map_location=device))
model.eval()  # inference mode: disables dropout


def predict_image(image_path):
    """Classify one image file and return 'cat' or 'dog'.

    Applies the module-level ``transform`` used for the datasets, runs
    ``model`` on ``device`` and maps class index 0 -> 'cat', 1 -> 'dog'.
    """
    # Close the file handle; convert() returns an independent in-memory image.
    with Image.open(image_path) as img:
        image = img.convert('RGB')
    batch = transform(image).unsqueeze(0).to(device)  # add batch dimension
    with torch.no_grad():
        output = model(batch)
    predicted = output.argmax(dim=1).item()
    return 'cat' if predicted == 0 else 'dog'


# Predict a single image (replace the placeholder path).
image_path = 'path_to_test_image.jpg'
prediction = predict_image(image_path)
print(f"The image is a {prediction}")

9 完整代码

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
import os


class CatDogDataset(Dataset):
    """Binary cat/dog image dataset backed by one sub-directory per class.

    Every accepted file in ``root_dir/<classes[i]>`` is loaded with label
    ``i``; with the default ``classes`` a cat image is 0 and a dog image is 1.

    Args:
        root_dir: Directory that contains the class sub-directories.
        transform: Optional callable applied to each RGB PIL image in
            ``__getitem__`` (e.g. a torchvision ``transforms.Compose``).
        classes: Class sub-directory names, in label order.
        extensions: Accepted file suffixes (case-insensitive). Non-image files
            that often sit next to the pictures (e.g. ``Thumbs.db``) would
            otherwise make ``Image.open`` fail mid-epoch. ``None`` accepts
            every regular file.
    """

    IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.gif', '.webp', '.tif', '.tiff')

    def __init__(self, root_dir, transform=None, classes=('cats', 'dogs'),
                 extensions=IMG_EXTENSIONS):
        self.root_dir = root_dir
        self.transform = transform
        self.classes = list(classes)
        self.image_paths = []
        self.labels = []
        suffixes = None if extensions is None else tuple(e.lower() for e in extensions)
        # Walk the cats and dogs directories; the class's index is its label.
        for idx, class_name in enumerate(self.classes):
            class_dir = os.path.join(root_dir, class_name)
            # os.listdir returns entries in arbitrary order; sort so dataset
            # indices are reproducible across runs and machines.
            for img_name in sorted(os.listdir(class_dir)):
                img_path = os.path.join(class_dir, img_name)
                if not os.path.isfile(img_path):
                    continue  # skip nested directories
                if suffixes is not None and not img_name.lower().endswith(suffixes):
                    continue  # skip non-image files
                self.image_paths.append(img_path)
                self.labels.append(idx)

    def __len__(self):
        """Return the number of images found."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        The image is converted to RGB so every sample has 3 channels, then
        passed through ``transform`` if one was given.
        """
        # The context manager closes the file handle; convert() returns a new
        # in-memory image, so it stays valid after the file is closed.
        with Image.open(self.image_paths[idx]) as img:
            image = img.convert('RGB')
        label = self.labels[idx]
        if self.transform:
            image = self.transform(image)
        return image, label
# Data preprocessing: resize to the 64x64 input size the CNN's fc1 layer is
# sized for, convert to a tensor, then normalize with ImageNet channel stats.
transform = transforms.Compose([
    transforms.Resize((64, 64)),  # resize image
    transforms.ToTensor(),  # PIL image -> float tensor in [0, 1]
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # normalize
])

# Load the training and validation datasets.
train_dataset = CatDogDataset(root_dir='D:/Cache/dataset/PetImages/train', transform=transform)
val_dataset = CatDogDataset(root_dir='D:/Cache/dataset/PetImages/valid', transform=transform)
# Data loaders: only the training set is shuffled.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)


class CNN(nn.Module):
    """Small two-conv-layer CNN producing 2 class logits (cat / dog).

    Expects input of shape (N, 3, 64, 64). Each 3x3 conv (padding 1) keeps
    the spatial size and each 2x2 max-pool halves it, so the feature map that
    reaches ``fc1`` is 64 x 16 x 16. Returns raw logits of shape (N, 2).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.fc1 = nn.Linear(64 * 16 * 16, 512)
        self.fc2 = nn.Linear(512, 2)
        self.dropout = nn.Dropout(0.5)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.pool(self.relu(self.conv1(x)))  # (N, 32, 32, 32)
        x = self.pool(self.relu(self.conv2(x)))  # (N, 64, 16, 16)
        # Flatten per sample. Unlike view(-1, 64 * 16 * 16), this keeps the
        # batch dimension, so a wrongly sized input raises in fc1 instead of
        # being silently reshaped into a different batch size.
        x = torch.flatten(x, 1)
        x = self.dropout(self.relu(self.fc1(x)))
        return self.fc2(x)


model = CNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
num_epochs = 10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)


def train():
    """Train ``model`` for ``num_epochs`` epochs, validating after each one.

    Uses the module-level ``train_loader``, ``val_loader``, ``criterion``,
    ``optimizer`` and ``device``. Prints the mean training loss and the
    validation accuracy per epoch, then saves the final weights to
    'cat_dog_classifier.pth'.
    """
    for epoch in range(num_epochs):
        model.train()  # enable dropout
        running_loss = 0.0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        # Mean of per-batch losses; max(..., 1) guards against an empty loader.
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss / max(len(train_loader), 1):.4f}")

        # Validate the model.
        model.eval()  # disable dropout
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                predicted = model(images).argmax(dim=1)  # index of the larger logit
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        if total:
            print(f"Validation Accuracy: {100 * correct / total:.2f}%")
        else:
            print("Validation Accuracy: n/a (empty validation set)")

    torch.save(model.state_dict(), 'cat_dog_classifier.pth')
def predict_image(image_path):
    """Classify one image file and return 'cat' or 'dog'.

    Applies the module-level ``transform`` used for the datasets, runs
    ``model`` on ``device`` and maps class index 0 -> 'cat', 1 -> 'dog'
    (the order of ``CatDogDataset.classes``).
    """
    # Close the file handle; convert() returns an independent in-memory image.
    with Image.open(image_path) as img:
        image = img.convert('RGB')
    batch = transform(image).unsqueeze(0).to(device)  # add batch dimension
    with torch.no_grad():
        output = model(batch)
    predicted = output.argmax(dim=1).item()
    return 'cat' if predicted == 0 else 'dog'


def test():
    """Load the saved weights and classify one sample image."""
    # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
    model.load_state_dict(torch.load('cat_dog_classifier.pth', map_location=device))
    model.eval()
    image_path = 'D:/develop/pytorch/dogcat/img/training/cats/4.jpg'
    prediction = predict_image(image_path)
    print(f"The image is a {prediction}")


import matplotlib.pyplot as plt


def test1():
    """Load the saved weights, classify one sample image and display it."""
    image_path = 'D:/develop/pytorch/dogcat/img/training/cats/3.jpg'
    model.load_state_dict(torch.load('cat_dog_classifier.pth', map_location=device))
    model.eval()
    prediction = predict_image(image_path)
    print("Predicted class:", prediction)
    # Display the raw RGB image rather than a normalized tensor: normalized
    # values fall outside [0, 1] and imshow would clip them (wrong colors).
    with Image.open(image_path) as img:
        image = img.convert('RGB')
    plt.imshow(image)
    plt.title(f"Predicted: {prediction}")
    plt.show()


if __name__ == '__main__':
    train()
    test1()
D:\develop\pytorch\dogcat>python3.7 dogVsCat.py
C:\Users\yosola\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.7_qbz5n2kfra8p0\LocalCache\local-packages\Python37\site-packages\PIL\TiffImagePlugin.py:864: UserWarning: Truncated File Read
  warnings.warn(str(msg))
Epoch [1/10], Loss: 0.5782
Validation Accuracy: 76.25%
Epoch [2/10], Loss: 0.4676
Validation Accuracy: 81.10%
Epoch [3/10], Loss: 0.4201
Validation Accuracy: 84.90%
Epoch [4/10], Loss: 0.3605
Validation Accuracy: 88.25%
Epoch [5/10], Loss: 0.2949
Validation Accuracy: 92.50%
Epoch [6/10], Loss: 0.2234
Validation Accuracy: 95.90%
Epoch [7/10], Loss: 0.1562
Validation Accuracy: 98.00%
Epoch [8/10], Loss: 0.1069
Validation Accuracy: 98.60%
Epoch [9/10], Loss: 0.0907
Validation Accuracy: 99.70%
Epoch [10/10], Loss: 0.0785
Validation Accuracy: 99.50%
Predicted class: cat

在这里插入图片描述

10. 总结

我们定义了一个自定义数据集类 CatDogDataset 来加载猫狗图片。

使用 PyTorch 的 DataLoader 加载数据。

定义了一个简单的 CNN 模型进行训练。

保存训练好的模型，并使用该模型进行预测。

你可以根据需要调整模型的架构、超参数和数据增强方法。希望这个示例对你有帮助！


http://www.ppmy.cn/ops/160421.html

相关文章

管理WSL实例 以及安装 Ubuntu 作为 WSL 子系统 流程

安装ubuntu wsl --install -d Ubuntu分类命令说明安装相关wsl --install在 Windows 10/11 上以管理员身份在 PowerShell 中运行此命令，可安装 WSL wsl --install -d <distribution name>在 PowerShell 中使用此命令安装特定版本的 Linux 发行版，如…

HTTP.

HTTP主要讲一下状态码和缓存机制 1xx 类状态码属于提示信息，是协议处理中的一种中间状态，如http升级为websocket，会提示1xx 2xx 类状态码表示服务器成功处理了客户端的请求 「200 OK」是最常见的成功状态码「204 No Content」也是常见的成功…

DCA考试备考

目录标题 考试内容指南一、考试环境准备（一）创建单实例数据库（二）管理数据库对象 二、数据操作（一）数据导入（二）参数修改 三、备份与恢复（一）备份（…

IB网络错误检查工具ibqueryerrors

ibqueryerrors 是一个用于查询 InfiniBand 网络中错误统计信息的工具。它可以帮助网络管理员识别和诊断网络问题，如丢包、重传和其他通信错误。这个工具通常是 InfiniBand 管理软件包的一部分，例如 OpenSM（Open Subnet Manager）。…

C++ Primer 库-IO类

欢迎阅读我的 【C++ Primer】专栏 专栏简介：本专栏主要面向 C++ 初学者，解释 C++ 的一些基本概念和基础语言特性，涉及 C++ 标准库的用法，面向对象特性，泛型特性高级用法。通过使用标准库中定义的抽象设施，使你更加适应高级…

基于Flask框架的食谱数据可视化分析系统的设计与实现

【Flask】基于Flask框架的食谱数据可视化分析系统的设计与实现 （完整系统源码+开发笔记+详细部署教程）✅ 目录 一、项目简介二、项目界面展示三、项目视频展示 一、项目简介 在当今数字化时代，信息可视化已成为一种高效的数据理解和传播手段。…

【分布式理论11】分布式协同之分布式事务(一个应用操作多个资源):从刚性事务到柔性事务的演进

文章目录 一. 什么是分布式事务？二. 分布式事务的挑战三. 事务的ACID特性四. CAP理论与BASE理论1. CAP理论1.1. 三大特性1.2. 三者不能兼得 2. BASE理论 五. 分布式事务解决方案1. 两阶段提交（2PC）2. TCC（Try-Confirm-Cancel）…

unity学习46:反向动力学IK

目录 1 正向动力学和反向动力学 1.1 正向动力学 1.2 反向动力学 1.3 实现目标 2 实现反向动力 2.1 先定义一个目标 2.2 动画层layer，需要加 IK pass 2.3 增加头部朝向代码 2.3.1 专门的IK方法 OnAnimatorIK(int layerIndex){} 2.3.2 增加朝向代码 2.4 …