用pytorch实现一个简单的图片预测类别

devtools/2025/2/7 11:13:43/

前言:

        在阅读本文之前,你需要了解Python,Pytorch,神经网络的一些基础知识,比如什么是数据集,什么是张量,什么是神经网络,如何简单使用tensorboard,DataLoader。

        本次模型训练使用的是cpu。

目录

python%E6%96%87%E4%BB%B6%EF%BC%9A-toc" name="tableOfContents" style="margin-left:40px">创建python文件:

model_train.py文件

1、准备数据集

2、打印看看该数据集的大小

3、加载数据集

4、创建网络模型

model.py文件

5、定义损失函数

6、定义优化器

7、设置一些训练网络所需参数

8、可视化训练过程

9、训练过程

model_vertification文件


python%E6%96%87%E4%BB%B6%EF%BC%9A" name="%E5%88%9B%E5%BB%BApython%E6%96%87%E4%BB%B6%EF%BC%9A">创建python文件:

        model.py 自定义神经网络模型。

        model_train.py 训练 CIFAR - 10 数据集上的自定义模型并保存参数。

        model_vertification.py 用一张图片验证网络模型进行预测。

下面从各个文件讲解。

model_train.py文件

完整的模型训练步骤如下:

1、准备数据集

       这里选用 CIFAR10 数据集,这个数据集是 torchvision 里面自带的,一个十分类问题的数据集,该数据集较小(160MB左右),使用torchvision.datasets模块加载 CIFAR10 数据集。

# 下载并加载 CIFAR-10 训练数据集
# root 指定数据集存储的根目录;train=True 表示加载训练集;
# transform 将数据转换为 Tensor 类型;download=True 表示如果数据集不存在则进行下载
train_data = torchvision.datasets.CIFAR10(root= r'D:\Desktop\数据集', train=True, transform=torchvision.transforms.ToTensor(), download=True)
# 下载并加载 CIFAR-10 测试数据集
test_data = torchvision.datasets.CIFAR10(root= r'D:\Desktop\数据集', train=False, transform=torchvision.transforms.ToTensor(), download=True)

2、打印看看该数据集的大小

# 计算训练集和测试集的长度
train_data_size = len(train_data)
test_data_size = len(test_data)
print(f"训练数据集的长度为: {train_data_size}")
print(f"测试数据集的长度为: {test_data_size}")

        可以看到该数据集有50000张训练图片,10000张测试图片。 

3、加载数据集

        使用DataLoader分别加载训练和测试数据集。

# 使用 DataLoader 对训练数据进行批量加载,batch_size=64 表示每个批次包含 64 个样本
train_dataloader = DataLoader(train_data, batch_size=64)
# 对测试数据进行批量加载
test_dataloader = DataLoader(test_data, batch_size=64)

4、创建网络模型

        将自定义的网络模型放在model.py文件,在train.py中导入使用。

        根据此图片的神经网络模型来自定义一个网络模型。(其中卷积层Conv2d中的参数stride和padding需要经过如下的公式计算得到,该计算并不复杂)。

        计算公式 (pytorch官网torch.nn中Conv2d中查看)

model.py文件

import torch
from torch import nnclass zzy(nn.Module):def __init__(self):super(zzy, self).__init__()self.model = nn.Sequential(nn.Conv2d(3, 32, 5, 1, 2),nn.MaxPool2d(2),nn.Conv2d(32, 32, 5, 1, 2),nn.MaxPool2d(2),nn.Conv2d(32, 64, 5, 1, 2),nn.MaxPool2d(2),nn.Flatten(),nn.Linear(64 * 4 * 4, 64),nn.Linear(64,10))def forward(self, x):x = self.model(x)return xif __name__ == '__main__':zzy1 = zzy()data = torch.ones((64,3,32,32))output = zzy1(data)print(output.shape)

        运行此文件验证该网络模型是否能得到预想的结果 ,如下。

         在train.py中导入model.py后实例化网络模型。

# 创建网络模型
# 实例化自定义的网络模型 zzy
zzy1 = zzy()

5、定义损失函数

# 定义损失函数
# 使用交叉熵损失函数,常用于多分类问题
loss_fn = nn.CrossEntropyLoss()

6、定义优化器

# 优化器创建
# 学习率设置为 0.01
learning_rate = 1e-2
# 使用随机梯度下降(SGD)优化器,对 zzy1 模型的参数进行优化
optimier = torch.optim.SGD(zzy1.parameters(), lr=learning_rate)

7、设置一些训练网络所需参数

# 设置训练网络参数
# 记录总的训练步数
total_train_step = 0
# 记录总的测试步数
total_test_step = 0
# 训练的轮数
epoch = 10

8、可视化训练过程

        为了可视化整个训练过程, 添加 TensorBoard 用于可视化训练过程。

# 在 '../logs_train' 目录下创建 SummaryWriter 对象
writer = SummaryWriter('../logs_train')

9、训练过程

        首先设置我们的训练轮次,这是外层循环,内层循环里分别进行训练数据集和测试数据集的训练。

        在训练集中,每次取出一批数据,即前面定义的64张图片,送入我们定义的网络模型得到输出,计算输出和真实值之间的损失,再进行反向传播更新模型参数进行优化,训练步数加一,再取出下一批数据,重复上面的过程。

        在测试集中,计算正确率。

# 开始训练循环,共训练 epoch 轮
for i in range(epoch):print(f'--------第{i + 1}次训练--------')# 遍历训练数据加载器中的每个批次for data in train_dataloader:# 从数据批次中解包图像和对应的标签imgs, target = data# 将图像输入到模型中进行前向传播,得到模型的输出outputs = zzy1(imgs)# 计算模型输出与真实标签之间的损失loss = loss_fn(outputs, target)# 梯度清零,防止梯度累积optimier.zero_grad()# 反向传播,计算梯度loss.backward()# 根据计算得到的梯度更新模型的参数optimier.step()# 训练步数加 1total_train_step += 1# 每训练 100 步,打印一次训练信息并将训练损失写入 TensorBoardif total_train_step % 100 == 0:print(f'训练次数:{total_train_step},Loss:{loss.item()}')# 将训练损失添加到 TensorBoard 中,用于后续可视化writer.add_scalar('train_loss', loss.item(), total_train_step)zzy.eval()# 测试步骤开始# 初始化总的测试损失为 0total_test_loss = 0# 初始化正确率为 0total_accuracy = 0# 上下文管理器,在测试过程中不进行梯度计算,减少内存消耗with torch.no_grad():# 遍历测试数据加载器中的每个批次for data in test_dataloader:# 从数据批次中解包图像和对应的标签imgs, targets = data# 将图像输入到模型中进行前向传播,得到模型的输出outputs = zzy1(imgs)# 计算模型输出与真实标签之间的损失loss = loss_fn(outputs, targets)# 计算正确率,outputs.argmax(1)表示横向看accuracy = (outputs.argmax(1) == targets).sum()# 累加测试损失total_test_loss += loss.item()total_accuracy += accuracyprint(f'整体的测试损失: {total_test_loss}')print(f'整体的正确率: {total_accuracy/test_data_size}')# 将测试损失添加到 TensorBoard 中,用于后续可视化writer.add_scalar('test_loss', total_test_loss, total_test_step)writer.add_scalar('total_accuracy',total_accuracy/test_data_size,total_test_step)# 测试步数加 1total_test_step += 1torch.save(zzy,'zzy_{}.pth'.format(i))# 官方推荐的保存方式# torch.save(zzy.state_dict(),'zzy_{}.pth'.forma(i))print("模型已保存")
# 关闭 SummaryWriter,释放资源
writer.close()

完整的model_train.py代码如下:

import torch
import torchvision.transforms
from torch.utils.tensorboard import SummaryWriter
from torch import nn
from torch.utils.data import DataLoader# 下载并加载 CIFAR-10 训练数据集
train_data = torchvision.datasets.CIFAR10(root=r'D:\Desktop\数据集', train=True, transform=torchvision.transforms.ToTensor(), download=True)
# 下载并加载 CIFAR-10 测试数据集
test_data = torchvision.datasets.CIFAR10(root=r'D:\Desktop\数据集', train=False, transform=torchvision.transforms.ToTensor(), download=True)# 计算训练集和测试集的长度
train_data_size = len(train_data)
test_data_size = len(test_data)
print(f"训练数据集的长度为: {train_data_size}")
print(f"测试数据集的长度为: {test_data_size}")# 加载数据
train_dataloader = DataLoader(train_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)# 创建网络模型
class zzy(nn.Module):def __init__(self):super(zzy, self).__init__()self.model = nn.Sequential(nn.Conv2d(3, 32, 5, 1, 2),nn.MaxPool2d(2),nn.Conv2d(32, 32, 5, 1, 2),nn.MaxPool2d(2),nn.Conv2d(32, 64, 5, 1, 2),nn.MaxPool2d(2),nn.Flatten(),nn.Linear(64 * 4 * 4, 64),nn.Linear(64, 10))def forward(self, x):x = self.model(x)return x# 实例化自定义的网络模型 zzy
zzy1 = zzy()# 定义损失函数
loss_fn = nn.CrossEntropyLoss()# 优化器创建
learning_rate = 1e-2
optimier = torch.optim.SGD(zzy1.parameters(), lr=learning_rate)# 设置训练网络参数
total_train_step = 0
total_test_step = 0
epoch = 10# 添加 TensorBoard 用于可视化训练过程
writer = SummaryWriter('../logs_train')# 开始训练循环,共训练 epoch 轮
for i in range(epoch):print(f'--------第{i + 1}次训练--------')# 遍历训练数据加载器中的每个批次for data in train_dataloader:# 从数据批次中解包图像和对应的标签imgs, target = data# 将图像输入到模型中进行前向传播,得到模型的输出outputs = zzy1(imgs)# 计算模型输出与真实标签之间的损失loss = loss_fn(outputs, target)# 梯度清零,防止梯度累积optimier.zero_grad()# 反向传播,计算梯度loss.backward()# 根据计算得到的梯度更新模型的参数optimier.step()# 训练步数加 1total_train_step += 1# 每训练 100 步,打印一次训练信息并将训练损失写入 TensorBoardif total_train_step % 100 == 0:print(f'训练次数:{total_train_step},Loss:{loss.item()}')# 将训练损失添加到 TensorBoard 中,用于后续可视化writer.add_scalar('train_loss', loss.item(), total_train_step)# 设置模型为评估模式zzy1.eval()# 测试步骤开始total_test_loss = 0total_accuracy = 0with torch.no_grad():for data in test_dataloader:imgs, targets = dataoutputs = zzy1(imgs)loss = loss_fn(outputs, targets)accuracy = (outputs.argmax(1) == targets).sum()total_test_loss += loss.item()total_accuracy += accuracyprint(f'整体的测试损失: {total_test_loss}')print(f'整体的正确率: {total_accuracy / test_data_size}')# 将测试损失添加到 TensorBoard 中,用于后续可视化writer.add_scalar('test_loss', total_test_loss, total_test_step)writer.add_scalar('total_accuracy', total_accuracy / test_data_size, total_test_step)# 测试步数加 1total_test_step += 1# 保存模型状态字典torch.save(zzy1.state_dict(), f'zzy_{i}.pth')print("模型已保存")# 关闭 SummaryWriter,释放资源
writer.close()

运行结果:

        从下图可以看到,最后训练出来的模型预测正确率为0.54左右,不算好,如果想继续优化,加大训练轮次,或者调整学习率。 

        打开tensorboard观察到整个训练过程的变化,图中深色线是经过平滑处(Smoothed)的训练损失值,能更清晰呈现损失总体变化趋势,减少波动干扰;浅色线代表原始的训练损失值,反映每个训练步骤上即时的损失情况,波动相对较大。

 

后续补充:

        想查看整个训练所用时间,可以导入time模块,设置一下开始训练时间和结束训练时间求差。

        当我把训练次数加大10倍(100次)后,模型预测的正确率为0.63左右,相对0.53没有提高很多,而且训练轮次较多时或者数据量较大时用cpu计算的时间花费就比较多了 (不要参考下面的时间,中途暂停过较长时间)。

        到中期训练了40轮次后,从正确率的变化可以看出,模型效果不佳 。

        当然可能有很多原因导致,数据量不足,网络模型结构不合理,优化器选择不合理等等,这里不过多赘述。

model_vertification文件 

        网上随便找了一个狗狗的图片保存为image.png,我们想要验证该网络模型的预测效果。

        注意:

        1、 模型训练时,模型保存方式和加载该模型的方式要对应。

        2、 需要将图片更改为网络模型能够处理的shape。

完整代码:

import torch
import torchvision
from PIL import Image
from torch import nn
from torch.nn import Sequential, Conv2d, MaxPool2d, Flatten, Linear
from torch.utils.tensorboard import SummaryWriter
from model import  *
# 打开图片,使用绝对路径
image_path = r'D:\Desktop\deep_learning\pytorch入门\images\image.png'
image = Image.open(image_path)
print(image.size)# 保留颜色通道
image = image.convert('RGB')# 定义图像变换
# 首先转化尺寸,再转化为tensor类型
transform = torchvision.transforms.Compose([torchvision.transforms.Resize((32, 32)),torchvision.transforms.ToTensor()
])# 应用变换
image = transform(image)
print(f'变换后的{image.shape}')
# 增加批量维度,以匹配模型输入要求
# 图像数据,常见的输入形状是 (batch_size, channels, height, width)
# 示例说明:假设 image 是一个形状为 (3, 32, 32) 的张量,代表一张 3 通道、高度为 32、宽度为 32 的图像。
# 当调用 image.unsqueeze(0) 后,得到的新张量形状将变为 (1, 3, 32, 32),
# 这里的 1 就是新插入的批量维度,表示这个批量中只有一张图像。
# unsqueeze 只能插入一个大小为 1 的维度
image = image.unsqueeze(0)
# 或者使用这种方式来更改图像shape
# image = torch.reshape(image, (1, 3, 32, 32))print(image.shape)# 实例化模型
model = zzy()# 加载模型的状态字典
model_path = r'D:\Desktop\deep_learning\model_train\zzy_9.pth'  # 确保路径正确
model.load_state_dict(torch.load(model_path))
model.eval()# 进行前向传播
# 不要漏掉 with torch.no_grad():
# 我们的目标仅仅是根据输入数据得到模型的预测结果
# 并不需要更新模型的参数,计算梯度是不必要的开销。
with torch.no_grad():output = model(image)# 获取预测的类别
# _,表示一个占位符,只关心另一个值的输出
_, predicted = torch.max(output.data, 1)# CIFAR-10 数据集的类别名称
classes = ('plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck')# 打印预测结果
print(f"预测的类别是: {classes[predicted.item()]}")

运行结果:

        正确预测到了这个类别为dog。

        作者水平有限,有任何问题或错误,欢迎留言,我将持续分享深度学习相关的内容,你的投币点赞是我最大的创作动力!

        本文代码也可以在我的github上直接下载。https://github.com/Zik-code/CIFAR-10-model_train/tree/main/model_train。
 


http://www.ppmy.cn/devtools/156790.html

相关文章

Sui 年度展望:2025 是走向主流的一年,将 Sui 打造成体验最友好的平台

作者:Adeniyi.sui 编译:深潮 TechFlow Mysten Labs 正与 CarnegieMellon (卡内基梅隆大学)的研究人员紧密合作,共同开发和优化可编程的点对点 (P2P) 隧道。这项技术将为区块链的应用场景带来更多可能性。 展望 2025…

3. k8s二进制集群之负载均衡器高可用部署

Haproxy 和 Keepalived安装Haproxy配置文件准备Keepalived配置及健康检查启动Haproxy & Keepalived服务继续上一篇文章《K8S集群架构及主机准备》,下面介绍负载均衡器搭建过程 Haproxy 和 Keepalived安装 在负载均衡器两个主机上安装即可 apt install haproxy keepalived…

vue 弹窗 模板

<template><el-dialogtitle"选择批号":visible.sync"showFlag"width"800px"append-to-body:destroy-on-close"true"open"handleOpen"><el-form :model"queryParams" ref"queryForm" :in…

Excel交换列位置

在Excel中拖动列以调整位置&#xff0c;可以按照以下步骤操作&#xff1a; 方法一&#xff1a;使用鼠标拖动 选择整列&#xff1a;点击列标&#xff08;如A、B、C&#xff09;选中要移动的列。拖动列&#xff1a; 将鼠标移到列边框&#xff0c;光标变为四向箭头。按住鼠标左键…

JVM监控和管理工具

基础故障处理工具 jps jps(JVM Process Status Tool)&#xff1a;Java虚拟机进程状态工具 功能 1&#xff1a;列出正在运行的虚拟机进程 2&#xff1a;显示虚拟机执行主类(main()方法所在的类) 3&#xff1a;显示进程ID(PID&#xff0c;Process Identifier) 命令格式 jps […

牛客比赛贪心算法

题目如下 代码及解析如下 谢谢观看&#xff01;&#xff01;&#xff01;

[250202] DocumentDB 开源发布:基于 PostgreSQL 的文档数据库新选择 | Jekyll 4.4.0 发布

目录 DocumentDB 开源发布&#xff1a;基于 PostgreSQL 的文档数据库新选择DocumentDB 的使命DocumentDB 的架构 Jekyll 4.4.0 版本发布&#x1f195; 新特性与改进 DocumentDB 开源发布&#xff1a;基于 PostgreSQL 的文档数据库新选择 微软近日宣布开源 DocumentDB&#xff…

【自学笔记】Python的基础知识点总览-持续更新

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 Python基础知识总览1. Python简介2. 安装与环境配置3. 基本语法3.1 变量与数据类型3.2 控制结构3.3 函数与模块3.4 文件操作 4. 面向对象编程&#xff08;OOP&#…