完整的模型训练路线

devtools/2024/9/18 12:36:08/ 标签: 机器学习, 深度学习, 人工智能, pytorch, pycharm

1.完整的模型训练套路:

完成CIFAR10的分类问题

1.1准备数据集:

其实用len去查看数据集的长度已经不是新知识点了。当我们要重写Dataset类的时候,关键需要重写Dataset类的__len__()方法和__getitem__()方法。

train_data = torchvision.datasets.CIFAR10(root="./data", train=True, transform=torchvision.transforms.ToTensor(),download=True)
test_data = torchvision.datasets.CIFAR10(root="./data", train=False, transform=torchvision.transforms.ToTensor(),download=True)
train_data_size = len(train_data)
test_data_size = len(test_data)
print("训练集的长度为:{}".format(train_data_size))
print("测试的长度为:{}".format(test_data_size))

1.2利用DataLoader来加载数据集:

# 利用 DataLoader 来加载数据集
train_dataloader = DataLoader(train_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)

1.3搭建神经网络:

将搭建的网络模型放入单独的一个model.py文件中,并进行验证。

import torch
from torch import nn# 搭建神经网络
class Tudui(nn.Module):def __init__(self):super(Tudui, self).__init__()self.model = nn.Sequential(nn.Conv2d(3, 32, 5, 1, padding=2),nn.MaxPool2d(2),nn.Conv2d(32, 32, 5, 1, padding=2),nn.MaxPool2d(2),nn.Conv2d(32, 64, 5, 1, padding=2),nn.MaxPool2d(2),nn.Flatten(),nn.Linear(64*4*4, 64),nn.Linear(64, 10))def forward(self, x):x = self.model(x)return xif __name__ == '__main__':tudui = Tudui()input=torch.ones((64,3,32,32))output=tudui(input)print(output.shape)

1.4创建网络模型:

按住Ctril然后点击类名可以查看源代码。

from model import *#创建网络模型
tudui = Tudui()

1.5创建损失函数:

#创建损失函数
loss_fn=nn.CrossEntropyLoss()

1.6设置优化器:

推荐使用科学计数法表示学习率。

#定义优化器
learning_rate=1e-2
#learning_rate=0.01
optimizer=torch.optim.SGD(tudui.parameters(),lr=learning_rate)

1.7设置训练网络的一些参数:

#设置训练网络的一些参数
#记录训练的次数
total_train_step=0
#记录测试的次数
total_test_step=0
#训练的轮数
epoch=10
for i in range(epoch):print("--------第{}轮训练开始----------".format(i+1))#训练步骤开始for data in train_dataloader:imgs, targets = dataoutputs = tudui(imgs)loss = loss_fn(outputs, targets)#优化器调优optimizer.zero_grad()loss.backward()optimizer.step()total_train_step=total_train_step+1print("训练次数:{},loss:{}".format(total_train_step,loss.item()))

在这里插入图片描述

2.完整的模型测试:

2.1设置测试部分:

用with torch.no_grad():环境取消梯度。

for i in range(epoch):print("--------第{}轮训练开始----------".format(i+1))#训练步骤开始for data in train_dataloader:imgs, targets = dataoutputs = tudui(imgs)loss = loss_fn(outputs, targets)#优化器调优optimizer.zero_grad()loss.backward()optimizer.step()total_train_step=total_train_step+1if total_train_step%100==0:print("训练次数:{},loss:{}".format(total_train_step,loss.item()))#测试步骤开始total_test_loss=0with torch.no_grad():for data in test_dataloader:imgs, targets = dataoutputs = tudui(imgs)loss = loss_fn(outputs, targets)total_test_loss=total_test_loss+lossprint("整体测试集上的Loss:{}".format(total_test_loss))

在这里插入图片描述

2.2用tensorboard显示loss的图像:

添加参数

#添加 tensorboard
writer=SummaryWriter("./logs_train")

在训练步骤中添加:

total_train_step=total_train_step+1if total_train_step%100==0:print("训练次数:{},loss:{}".format(total_train_step,loss.item()))writer.add_scalar("train_loss",loss.item(),total_train_step)

在测试步骤后添加:

writer.add_scalar("test_loss",total_test_loss,total_test_step)total_test_step=total_test_step+1

在for i in range(epoch)循环外添加:

writer.close()

在这里插入图片描述

2.3保存训练参数:

在for i in range(epoch)外添加:

torch.save(tudui,"tudui_{}.pth".format(i))print("模型已保存")

2.4利用torch.argmax函数计算准确率:

  • torch.argmax(predictions, dim=0)
    • dim=0:沿着行方向(样本方向)获取最大值的索引。
    • dim=1:沿着列方向(特征方向)获取最大值的索引。
total_accuracy=0
print("整体测试集上的正确率:{}".format(total_accuracy/test_data_size))writer.add_scalar("test_accuracy",total_accuracy/test_data_size,total_test_step)

在这里插入图片描述

3.训练细节总结:

  • model.train() 和 model.eval ()
    • 在官网的torch.nn.Module小节中可以查看train 和eval
    • model.train() 将模块设置为训练模式。这只对某些模块有影响,例如Dxopout、BatchNorm等。
    • model.eval ()将模块设置为验证模式。这只对某些模块有影响。这等效于self.Train(False)。
    • 最好还是加上。

在训练开始前加上.train(),在测试开始前加上.eval()。

4.完整代码

train.py

import tensorboard
import torch
import torchvision
from torch.utils.tensorboard import SummaryWriterfrom model import *
from torch import nn
from torch.utils.data import DataLoader# 准备数据集
train_data = torchvision.datasets.CIFAR10(root="./data", train=True, transform=torchvision.transforms.ToTensor(),download=True)
test_data = torchvision.datasets.CIFAR10(root="./data", train=False, transform=torchvision.transforms.ToTensor(),download=True)
train_data_size = len(train_data)
test_data_size = len(test_data)
print("训练集的长度为:{}".format(train_data_size))
print("测试的长度为:{}".format(test_data_size))# 利用 DataLoader 来加载数据集
train_dataloader = DataLoader(train_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)#创建网络模型
tudui = Tudui()#创建损失函数
loss_fn=nn.CrossEntropyLoss()#定义优化器
learning_rate=0.01
optimizer=torch.optim.SGD(tudui.parameters(),lr=learning_rate)#设置训练网络的一些参数
#记录训练的次数
total_train_step=0
#记录测试的次数
total_test_step=0
#训练的轮数
epoch=10#添加 tensorboard
writer=SummaryWriter("./logs_train")for i in range(epoch):print("--------第{}轮训练开始----------".format(i+1))#训练步骤开始tudui.train()for data in train_dataloader:imgs, targets = dataoutputs = tudui(imgs)loss = loss_fn(outputs, targets)#优化器调优optimizer.zero_grad()loss.backward()optimizer.step()total_train_step=total_train_step+1if total_train_step%100==0:print("训练次数:{},loss:{}".format(total_train_step,loss.item()))writer.add_scalar("train_loss",loss.item(),total_train_step)#测试步骤开始tudui.eval()total_test_loss=0total_accuracy=0with torch.no_grad():for data in test_dataloader:imgs, targets = dataoutputs = tudui(imgs)loss = loss_fn(outputs, targets)total_test_loss=total_test_loss+loss.item()accuracy=(outputs.argmax(1)==targets).sum()total_accuracy=total_accuracy+accuracyprint("整体测试集上的Loss:{}".format(total_test_loss))print("整体测试集上的正确率:{}".format(total_accuracy/test_data_size))writer.add_scalar("test_loss",total_test_loss,total_test_step)writer.add_scalar("test_accuracy",total_accuracy/test_data_size,total_test_step)total_test_step=total_test_step+1torch.save(tudui,"tudui_{}.pth".format(i))# torch.save(tudui.state_dict(),"tudui_{}.pth".format(i))print("模型已保存")writer.close()

model.py

import torch
from torch import nn# 搭建神经网络
class Tudui(nn.Module):def __init__(self):super(Tudui, self).__init__()self.model = nn.Sequential(nn.Conv2d(3, 32, 5, 1, padding=2),nn.MaxPool2d(2),nn.Conv2d(32, 32, 5, 1, padding=2),nn.MaxPool2d(2),nn.Conv2d(32, 64, 5, 1, padding=2),nn.MaxPool2d(2),nn.Flatten(),nn.Linear(64*4*4, 64),nn.Linear(64, 10))def forward(self, x):x = self.model(x)return xif __name__ == '__main__':tudui = Tudui()input=torch.ones((64,3,32,32))output=tudui(input)print(output.shape)

http://www.ppmy.cn/devtools/102888.html

相关文章

5步掌握“花开富贵”花园管理系统开发——基于Python Django+Vue

🍊作者:计算机毕设匠心工作室 🍊简介:毕业后就一直专业从事计算机软件程序开发,至今也有8年工作经验。擅长Java、Python、微信小程序、安卓、大数据、PHP、.NET|C#、Golang等。 擅长:按照需求定制化开发项目…

vue3 使用vue-masonry加载更多,重新渲染

在使用 van-list做上拉加载更多,加载下一页的时候,会出现瀑布图重叠,原因是布局没有重新更新,所以需要 调用 vue-masonry更新布局的方法。 看了源码才知道可以这样用,api都没写,隐藏太深了。。。 vue3中通…

CMake学习

Cmake 工具链 预处理器(宏替换)->编译器->汇编器(二进制文件.obj&&.0)->链接器->变成.exe 单源文件可以直接命令生成.exe 解决1:源文件多时要编写makefile,但编写makefile文件很麻烦解决2:使用cmake,跨平台&…

Yololov5+Pyqt5+Opencv 实时城市积水报警系统

在现代城市生活中,积水问题不仅影响交通和人们的日常生活,还可能对城市基础设施造成潜在的威胁。为了快速、准确地识别和应对积水问题,使用计算机视觉技术进行智能积水检测成为一个重要的解决方案。在这篇博客中,我将带你一步步实…

FastAPI vs Flask: 专业对比与选择

FastAPI与Flask是两个流行的Python Web框架,它们在构建Web应用程序和API方面各有特点。以下是对这两个框架的详细比较: 一、设计理念与用途 Flask: 是一个轻量级的Python Web框架,基于Werkzeug WSGI工具箱和Jinja2模板引擎。设计…

OSI和TCP/IP参考模型、协议与端口、DNS解析类型、数据封装

目录 1.OSI和TCP/IP参考模型 1.1 为什么要进行网络分层? 1.2 TCP/IP和OSI参考模型 1.3 TCP/IP参考模型对应协议 2.对应协议和端口 3.基于IP的封装 4.DNS解析类型 5.数据封装与解封过程分析 5.1 封装 1.OSI和TCP/IP参考模型 1.1 为什么要进行网络分层&am…

Apache SeaTunnel Zeta 引擎源码解析(一)Server端的初始化

引入 本系列文章是基于 Apache SeaTunnel 2.3.6版本,围绕Zeta引擎给大家介绍其任务是如何从提交到运行的全流程,希望通过这篇文档,对刚刚上手SeaTunnel的朋友提供一些帮助。 我们整体的文章将会分成三篇,从以下方向给大家介绍&am…

开关电源的基础特性

开关电源是一种利用高频开关技术将电能转换为所需电压或电流的设备。在现代电子设备中,开关电源由于其高效、紧凑和可调节的特性,被广泛应用于各种领域。 1. 纹波与噪声 开关电源的输出电压或电流并非绝对恒定,而是存在周期性的小幅度波动…

排行榜系统设计:高并发场景下的最佳实践

Hello,大家好!我是你们的技术分享小伙伴小米,29岁,喜欢技术,也喜欢分享各种有趣的项目经验。今天,我们来聊聊如何设计一个排行榜。 无论是游戏中的战力排行榜,还是电商平台的热销产品榜单,排行榜都在我们生活中扮演了重要的角色。而作为一个技术人,设计一个高效、稳定…

14.JS学习篇-CSR和SSR

在前端开发中,CSR(Client-Side Rendering,客户端渲染)和 SSR(Server-Side Rendering,服务端渲染)是两种不同的渲染方式。 一、CSR(客户端渲染) 1.工作原理:…

介绍 Java 的集合类

Java 的集合框架提供了一组标准化的接口和类,用于存储和操作一组对象(元素)。这些集合类位于 java.util 包中,并可以分为几大类:List、Set、Queue、Deque 和 Map。每一类集合都提供了不同的功能和特性,以满足不同的使用场景。 1. List 接口 List 是一种有序的集合,可以…

第六届机器人与智能制造技术国际会议 (ISRIMT 2024)

重要信息 大会官网:www.isrimt.org(点击了解大会,参会,投稿等信息) 大会时间:2024年9月20-22日 大会地点:中国-江苏常州 收录检索:IEEE Xplore, EI Compendex, Scopus 大会简介…

RabbitMQ 集群与高可用性

目录 单节点与集群部署 1.1. 单节点部署 1.2. 集群部署 镜像队列 1.定义与工作原理 2. 配置镜像队列 3.应用场景 4. 优缺点 5. Java 示例 分布式部署 1. 分布式部署的主要目标 2. 典型架构设计 3. RabbitMQ 分布式部署的关键技术 4. 部署策略和实践 5. 分布式部署…

【有来开源组织】开发规范手册

🚀 作者主页: 有来技术 🔥 开源项目: youlai-mall 🍃 vue3-element-admin 🍃 youlai-boot 🌺 仓库主页: Gitee 💫 Github 💫 GitCode 💖 欢迎点赞…

应对Nginx负载均衡中的请求超时:策略与配置

在Nginx负载均衡的部署中,处理请求超时是一个关键问题。请求超时不仅影响用户体验,还可能隐藏着后端服务的性能瓶颈。合理配置Nginx以处理超时情况,可以显著提高服务的稳定性和可靠性。本文将详细介绍如何在Nginx负载均衡中处理请求超时&…

1.两数之和

给定一个整数数组 nums 和一个整数目标值 target,请你在该数组中找出 和为目标值 target 的那 两个 整数,并返回它们的数组下标。 你可以假设每种输入只会对应一个答案,并且你不能使用两次相同的元素。 你可以按任意顺序返回答案。 审题最重…

C# 实现傅里叶变化(DFT)

1、DFT函数类 using System; using System.Collections.Generic; using System.Linq; using System.Text; using System.Threading.Tasks;namespace DFT_FFTApp.Utils {public class DFT{/// <summary>/// DFT/// </summary>/// <param name="data"&…

AI的未来已来:GPT-4商业应用带来的无限可能

随着人工智能技术的快速发展&#xff0c;OpenAI于2023年3月15日发布了多模态预训练大模型GPT-4&#xff0c;这一里程碑式的进步不仅提升了AI的语言处理能力&#xff0c;还拓展了其应用范围。本文将深入探讨GPT-4的技术进步、商业化进程、用户体验改善、伦理和社会影响&#xff…

【drools】Rulesengine构建及intelj配置

7.57.0.FinalRulesengineApplication 使用maven构建 intelj 打开文件资源管理器实在是太慢了所以直接把pom 扔到其主页识别为maven项目,自动下载maven包管理器 然后解析依赖: 给maven加一个代理 -DproxyHost=127.0.0.1 -DproxyPort=7890 还是卡主

运维监控工具 PIGOSS BSM :PostgreSQL数据库监控指标

在PostgreSQL数据库中&#xff0c;为了确保其稳定运行和性能优化&#xff0c;我们需要监控一系列关键的指标。以下是一些主要的PostgreSQL监控指标介绍&#xff1a; 连接数&#xff08;Connections&#xff09;&#xff1a; 定义&#xff1a;连接数是指当前正在与数据库建立连…