用pytorch实现一个简单的图片预测类别

server/2025/2/9 12:10:52/

前言:

        在阅读本文之前,你需要了解Python,Pytorch,神经网络的一些基础知识,比如什么是数据集,什么是张量,什么是神经网络,如何简单使用tensorboard,DataLoader。

        本次模型训练使用的是cpu。

目录

python%E6%96%87%E4%BB%B6%EF%BC%9A-toc" name="tableOfContents" style="margin-left:40px">创建python文件:

model_train.py文件

1、准备数据集

2、打印看看该数据集的大小

3、加载数据集

4、创建网络模型

model.py文件

5、定义损失函数

6、定义优化器

7、设置一些训练网络所需参数

8、可视化训练过程

9、训练过程

model_vertification文件


python%E6%96%87%E4%BB%B6%EF%BC%9A" name="%E5%88%9B%E5%BB%BApython%E6%96%87%E4%BB%B6%EF%BC%9A">创建python文件:

        model.py 自定义神经网络模型。

        model_train.py 训练 CIFAR - 10 数据集上的自定义模型并保存参数。

        model_vertification.py 用一张图片验证网络模型进行预测。

下面从各个文件讲解。

model_train.py文件

完整的模型训练步骤如下:

1、准备数据集

       这里选用 CIFAR10 数据集,这个数据集是 torchvision 里面自带的,一个十分类问题的数据集,该数据集较小(160MB左右),使用torchvision.datasets模块加载 CIFAR10 数据集。

# 下载并加载 CIFAR-10 训练数据集
# root 指定数据集存储的根目录;train=True 表示加载训练集;
# transform 将数据转换为 Tensor 类型;download=True 表示如果数据集不存在则进行下载
train_data = torchvision.datasets.CIFAR10(root= r'D:\Desktop\数据集', train=True, transform=torchvision.transforms.ToTensor(), download=True)
# 下载并加载 CIFAR-10 测试数据集
test_data = torchvision.datasets.CIFAR10(root= r'D:\Desktop\数据集', train=False, transform=torchvision.transforms.ToTensor(), download=True)

2、打印看看该数据集的大小

# 计算训练集和测试集的长度
train_data_size = len(train_data)
test_data_size = len(test_data)
print(f"训练数据集的长度为: {train_data_size}")
print(f"测试数据集的长度为: {test_data_size}")

        可以看到该数据集有50000张训练图片,10000张测试图片。 

3、加载数据集

        使用DataLoader分别加载训练和测试数据集。

# 使用 DataLoader 对训练数据进行批量加载,batch_size=64 表示每个批次包含 64 个样本
train_dataloader = DataLoader(train_data, batch_size=64)
# 对测试数据进行批量加载
test_dataloader = DataLoader(test_data, batch_size=64)

4、创建网络模型

        将自定义的网络模型放在model.py文件,在train.py中导入使用。

        根据此图片的神经网络模型来自定义一个网络模型。(其中卷积层Conv2d中的参数stride和padding需要经过如下的公式计算得到,该计算并不复杂)。

        计算公式 (pytorch官网torch.nn中Conv2d中查看)

model.py文件

import torch
from torch import nnclass zzy(nn.Module):def __init__(self):super(zzy, self).__init__()self.model = nn.Sequential(nn.Conv2d(3, 32, 5, 1, 2),nn.MaxPool2d(2),nn.Conv2d(32, 32, 5, 1, 2),nn.MaxPool2d(2),nn.Conv2d(32, 64, 5, 1, 2),nn.MaxPool2d(2),nn.Flatten(),nn.Linear(64 * 4 * 4, 64),nn.Linear(64,10))def forward(self, x):x = self.model(x)return xif __name__ == '__main__':zzy1 = zzy()data = torch.ones((64,3,32,32))output = zzy1(data)print(output.shape)

        运行此文件验证该网络模型是否能得到预想的结果 ,如下。

         在train.py中导入model.py后实例化网络模型。

# 创建网络模型
# 实例化自定义的网络模型 zzy
zzy1 = zzy()

5、定义损失函数

# 定义损失函数
# 使用交叉熵损失函数,常用于多分类问题
loss_fn = nn.CrossEntropyLoss()

6、定义优化器

# 优化器创建
# 学习率设置为 0.01
learning_rate = 1e-2
# 使用随机梯度下降(SGD)优化器,对 zzy1 模型的参数进行优化
optimier = torch.optim.SGD(zzy1.parameters(), lr=learning_rate)

7、设置一些训练网络所需参数

# 设置训练网络参数
# 记录总的训练步数
total_train_step = 0
# 记录总的测试步数
total_test_step = 0
# 训练的轮数
epoch = 10

8、可视化训练过程

        为了可视化整个训练过程, 添加 TensorBoard 用于可视化训练过程。

# 在 '../logs_train' 目录下创建 SummaryWriter 对象
writer = SummaryWriter('../logs_train')

9、训练过程

        首先设置我们的训练轮次,这是外层循环,内层循环里分别进行训练数据集和测试数据集的训练。

        在训练集中,每次取出一批数据,即前面定义的64张图片,送入我们定义的网络模型得到输出,计算输出和真实值之间的损失,再进行反向传播更新模型参数进行优化,训练步数加一,再取出下一批数据,重复上面的过程。

        在测试集中,计算正确率。

# 开始训练循环,共训练 epoch 轮
for i in range(epoch):print(f'--------第{i + 1}次训练--------')# 遍历训练数据加载器中的每个批次for data in train_dataloader:# 从数据批次中解包图像和对应的标签imgs, target = data# 将图像输入到模型中进行前向传播,得到模型的输出outputs = zzy1(imgs)# 计算模型输出与真实标签之间的损失loss = loss_fn(outputs, target)# 梯度清零,防止梯度累积optimier.zero_grad()# 反向传播,计算梯度loss.backward()# 根据计算得到的梯度更新模型的参数optimier.step()# 训练步数加 1total_train_step += 1# 每训练 100 步,打印一次训练信息并将训练损失写入 TensorBoardif total_train_step % 100 == 0:print(f'训练次数:{total_train_step},Loss:{loss.item()}')# 将训练损失添加到 TensorBoard 中,用于后续可视化writer.add_scalar('train_loss', loss.item(), total_train_step)zzy.eval()# 测试步骤开始# 初始化总的测试损失为 0total_test_loss = 0# 初始化正确率为 0total_accuracy = 0# 上下文管理器,在测试过程中不进行梯度计算,减少内存消耗with torch.no_grad():# 遍历测试数据加载器中的每个批次for data in test_dataloader:# 从数据批次中解包图像和对应的标签imgs, targets = data# 将图像输入到模型中进行前向传播,得到模型的输出outputs = zzy1(imgs)# 计算模型输出与真实标签之间的损失loss = loss_fn(outputs, targets)# 计算正确率,outputs.argmax(1)表示横向看accuracy = (outputs.argmax(1) == targets).sum()# 累加测试损失total_test_loss += loss.item()total_accuracy += accuracyprint(f'整体的测试损失: {total_test_loss}')print(f'整体的正确率: {total_accuracy/test_data_size}')# 将测试损失添加到 TensorBoard 中,用于后续可视化writer.add_scalar('test_loss', total_test_loss, total_test_step)writer.add_scalar('total_accuracy',total_accuracy/test_data_size,total_test_step)# 测试步数加 1total_test_step += 1torch.save(zzy,'zzy_{}.pth'.format(i))# 官方推荐的保存方式# torch.save(zzy.state_dict(),'zzy_{}.pth'.forma(i))print("模型已保存")
# 关闭 SummaryWriter,释放资源
writer.close()

完整的model_train.py代码如下:

import torch
import torchvision.transforms
from torch.utils.tensorboard import SummaryWriter
from torch import nn
from torch.utils.data import DataLoader# 下载并加载 CIFAR-10 训练数据集
train_data = torchvision.datasets.CIFAR10(root=r'D:\Desktop\数据集', train=True, transform=torchvision.transforms.ToTensor(), download=True)
# 下载并加载 CIFAR-10 测试数据集
test_data = torchvision.datasets.CIFAR10(root=r'D:\Desktop\数据集', train=False, transform=torchvision.transforms.ToTensor(), download=True)# 计算训练集和测试集的长度
train_data_size = len(train_data)
test_data_size = len(test_data)
print(f"训练数据集的长度为: {train_data_size}")
print(f"测试数据集的长度为: {test_data_size}")# 加载数据
train_dataloader = DataLoader(train_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)# 创建网络模型
class zzy(nn.Module):def __init__(self):super(zzy, self).__init__()self.model = nn.Sequential(nn.Conv2d(3, 32, 5, 1, 2),nn.MaxPool2d(2),nn.Conv2d(32, 32, 5, 1, 2),nn.MaxPool2d(2),nn.Conv2d(32, 64, 5, 1, 2),nn.MaxPool2d(2),nn.Flatten(),nn.Linear(64 * 4 * 4, 64),nn.Linear(64, 10))def forward(self, x):x = self.model(x)return x# 实例化自定义的网络模型 zzy
zzy1 = zzy()# 定义损失函数
loss_fn = nn.CrossEntropyLoss()# 优化器创建
learning_rate = 1e-2
optimier = torch.optim.SGD(zzy1.parameters(), lr=learning_rate)# 设置训练网络参数
total_train_step = 0
total_test_step = 0
epoch = 10# 添加 TensorBoard 用于可视化训练过程
writer = SummaryWriter('../logs_train')# 开始训练循环,共训练 epoch 轮
for i in range(epoch):print(f'--------第{i + 1}次训练--------')# 遍历训练数据加载器中的每个批次for data in train_dataloader:# 从数据批次中解包图像和对应的标签imgs, target = data# 将图像输入到模型中进行前向传播,得到模型的输出outputs = zzy1(imgs)# 计算模型输出与真实标签之间的损失loss = loss_fn(outputs, target)# 梯度清零,防止梯度累积optimier.zero_grad()# 反向传播,计算梯度loss.backward()# 根据计算得到的梯度更新模型的参数optimier.step()# 训练步数加 1total_train_step += 1# 每训练 100 步,打印一次训练信息并将训练损失写入 TensorBoardif total_train_step % 100 == 0:print(f'训练次数:{total_train_step},Loss:{loss.item()}')# 将训练损失添加到 TensorBoard 中,用于后续可视化writer.add_scalar('train_loss', loss.item(), total_train_step)# 设置模型为评估模式zzy1.eval()# 测试步骤开始total_test_loss = 0total_accuracy = 0with torch.no_grad():for data in test_dataloader:imgs, targets = dataoutputs = zzy1(imgs)loss = loss_fn(outputs, targets)accuracy = (outputs.argmax(1) == targets).sum()total_test_loss += loss.item()total_accuracy += accuracyprint(f'整体的测试损失: {total_test_loss}')print(f'整体的正确率: {total_accuracy / test_data_size}')# 将测试损失添加到 TensorBoard 中,用于后续可视化writer.add_scalar('test_loss', total_test_loss, total_test_step)writer.add_scalar('total_accuracy', total_accuracy / test_data_size, total_test_step)# 测试步数加 1total_test_step += 1# 保存模型状态字典torch.save(zzy1.state_dict(), f'zzy_{i}.pth')print("模型已保存")# 关闭 SummaryWriter,释放资源
writer.close()

运行结果:

        从下图可以看到,最后训练出来的模型预测正确率为0.54左右,不算好,如果想继续优化,加大训练轮次,或者调整学习率。 

        打开tensorboard观察到整个训练过程的变化,图中深色线是经过平滑处(Smoothed)的训练损失值,能更清晰呈现损失总体变化趋势,减少波动干扰;浅色线代表原始的训练损失值,反映每个训练步骤上即时的损失情况,波动相对较大。

 

后续补充:

        想查看整个训练所用时间,可以导入time模块,设置一下开始训练时间和结束训练时间求差。

        当我把训练次数加大10倍(100次)后,模型预测的正确率为0.63左右,相对0.53没有提高很多,而且训练轮次较多时或者数据量较大时用cpu计算的时间花费就比较多了 (不要参考下面的时间,中途暂停过较长时间)。

        到中期训练了40轮次后,从正确率的变化可以看出,模型效果不佳 。

        当然可能有很多原因导致,数据量不足,网络模型结构不合理,优化器选择不合理等等,这里不过多赘述。

model_vertification文件 

        网上随便找了一个狗狗的图片保存为image.png,我们想要验证该网络模型的预测效果。

        注意:

        1、 模型训练时,模型保存方式和加载该模型的方式要对应。

        2、 需要将图片更改为网络模型能够处理的shape。

完整代码:

import torch
import torchvision
from PIL import Image
from torch import nn
from torch.nn import Sequential, Conv2d, MaxPool2d, Flatten, Linear
from torch.utils.tensorboard import SummaryWriter
from model import  *
# 打开图片,使用绝对路径
image_path = r'D:\Desktop\deep_learning\pytorch入门\images\image.png'
image = Image.open(image_path)
print(image.size)# 保留颜色通道
image = image.convert('RGB')# 定义图像变换
# 首先转化尺寸,再转化为tensor类型
transform = torchvision.transforms.Compose([torchvision.transforms.Resize((32, 32)),torchvision.transforms.ToTensor()
])# 应用变换
image = transform(image)
print(f'变换后的{image.shape}')
# 增加批量维度,以匹配模型输入要求
# 图像数据,常见的输入形状是 (batch_size, channels, height, width)
# 示例说明:假设 image 是一个形状为 (3, 32, 32) 的张量,代表一张 3 通道、高度为 32、宽度为 32 的图像。
# 当调用 image.unsqueeze(0) 后,得到的新张量形状将变为 (1, 3, 32, 32),
# 这里的 1 就是新插入的批量维度,表示这个批量中只有一张图像。
# unsqueeze 只能插入一个大小为 1 的维度
image = image.unsqueeze(0)
# 或者使用这种方式来更改图像shape
# image = torch.reshape(image, (1, 3, 32, 32))print(image.shape)# 实例化模型
model = zzy()# 加载模型的状态字典
model_path = r'D:\Desktop\deep_learning\model_train\zzy_9.pth'  # 确保路径正确
model.load_state_dict(torch.load(model_path))
model.eval()# 进行前向传播
# 不要漏掉 with torch.no_grad():
# 我们的目标仅仅是根据输入数据得到模型的预测结果
# 并不需要更新模型的参数,计算梯度是不必要的开销。
with torch.no_grad():output = model(image)# 获取预测的类别
# _,表示一个占位符,只关心另一个值的输出
_, predicted = torch.max(output.data, 1)# CIFAR-10 数据集的类别名称
classes = ('plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck')# 打印预测结果
print(f"预测的类别是: {classes[predicted.item()]}")

运行结果:

        正确预测到了这个类别为dog。

        作者水平有限,有任何问题或错误,欢迎留言,我将持续分享深度学习相关的内容,你的投币点赞是我最大的创作动力!

        本文代码也可以在我的github上直接下载。https://github.com/Zik-code/CIFAR-10-model_train/tree/main/model_train。
 


http://www.ppmy.cn/server/166214.html

相关文章

DeepSeek和ChatGPT的对比

最近DeepSeek大放异彩,两者之间有什么差异呢?根据了解到的信息,简单做了一个对比。 DeepSeek 和 ChatGPT 是两种不同的自然语言处理(NLP)模型架构,尽管它们都基于 Transformer 架构,但在设计目标…

【Linux基础】Linux下常用的系统命令

一、前言 本文主要总结了工作中常用的linux指令,有遇到新的命令会不定期更新。 二、系统监控和进程管理指令 2.1 ps命令 作用:查看当前进程信息。 常用选项: -e: 显示所有进程,包括其他用户的进程。-f: 显示更详细的进程信息…

模型 冗余系统(系统科学)

系列文章分享模型,了解更多👉 模型_思维模型目录。为防故障、保运行的备份机制。 1 冗余系统的应用 1.1 冗余系统在企业管理中的应用-金融行业信息安全的二倍冗余技术 在金融行业,信息安全是保障业务连续性和客户资产安全的关键。随着数字化…

新注册的域名无法访问,是怎么回事?

域名是企业和个人线上身份的标识,是对外展示信息提供服务的窗口,其重要性不言而喻。然而,不少朋友在新注册域名后,却遭遇了无法访问的尴尬情况,这到底是怎么回事呢? 域名解析尚未生效 域名注册完成后&…

React 生命周期函数详解

React 组件在其生命周期中有多个阶段,每个阶段都有特定的生命周期函数(Lifecycle Methods)。这些函数允许你在组件的不同阶段执行特定的操作。以下是 React 组件生命周期的主要阶段及其对应的生命周期函数,并结合了 React 16.3 的…

探秘数据结构之单链表:从原理到实战的深度解析

目录 一、链表的概念及结构 1.1 链表的独特定义 1.2 火车车厢式的形象类比 1.3 节点的结构体定义剖析 1.4 链表物理与逻辑结构的特性差异 二、单链表的实现 2.1 类型定义的优化策略 2.2 链表操作函数的声明框架 2.3 链表操作函数的实现细节 三、链表的分类 前言 …

ES6-代码编程风格(数组、函数)

1 数组 使用扩展运算符(...)复制数组。 const itemsCopy [...items]; 使用Array.from 方法将类似数组的对象转为数组。 const foo document.querySelectorAll(.foo); const nodes Array.from(foo); 2 函数 立即执行函数可以写成箭头函数的形式…

Maven的三种项目打包方式——pom,jar,war的区别

Maven 是一个强大的项目管理和构建工具,广泛应用于Java项目的构建和管理。Maven 支持多种打包方式,其中最常用的三种是 pom、jar 和 war。理解这三种打包方式的区别,对于正确配置和管理项目至关重要。本文将详细解释这三种打包方式的用途、特…