模型训练套路(二)

embedded/2024/9/20 1:22:45/ 标签: 神经网络, 人工智能, 深度学习, 学习, pytorch

接模型训练套路(一)http://t.csdnimg.cn/gZ4Fm

得到预测的值:preds=[1][1],

输出目标:inputs target = [0][1];

查看两者的正确率,就看:preds==inputs target

输出的结果:[false][true].sum = 1

一、计算正确率的代码:

import torchoutput = torch.tensor([[0.1,0.2],[0.3,0.4]])
# argmax(1)表示横向比较
print(output.argmax(1))
preds = output.argmax(1)targets = torch.tensor([0, 1])print((preds == targets).sum())

接训练一代码,计算整体的计算率:在测试步骤开始中写:

total_accuracy = 0
accuracy = (outputs.argmax(1)==targets).sum()
total_accuracy = total_accuracy + accuracy
writer.add_scalar("test_accuracy",total_accuracy/test_data_size,total_test_step)

模型训练套路(三)

模型包括:1、数据集——>2、加载数据集——>3、创建网络模型——>4、损失函数——>5、优化器——>6、设置训练参数——>7、设置训练epoch——>8、训练开始(优化器优化模型、展示数据)——>9、测试开始(设置梯度为0,从测试数据集中取数据,计算loss,计算误差,构建需要指标并显示,展示训练的网络在测试集上的效果)——>10、保存模型

完整代码:

能够使用gpu的语句:

网络模型;

数据(输入,标注);imgs,targets

损失函数;

​
import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from model1 import*train_data = torchvision.datasets.CIFAR10(root = "../data", train=True,transform=torchvision.transforms.ToTensor(),download=True)
test_data = torchvision.datasets.CIFAR10(root = "../data", train=False,transform=torchvision.transforms.ToTensor(),download=True)
# 查看数据集的长度
train_data_size = len(train_data)
test_data_size = len(test_data)
# 格式化
print("训练数据集的长度为: {}".format(train_data_size))
print("测试数据集的长度为: {}".format(test_data_size))# 利用dataloader加载数据集
train_dataloader = DataLoader (train_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)# 创建网络模型
sun = SUN()# 损失函数
loss_fn = nn.CrossEntropyLoss()# 优化器(SGD随机梯度下降)
learning_rate = 0.01
optimizer = torch.optim.SGD(sun.parameters(), lr = learning_rate)# 设置网络训练的参数
# 记录训练的次数
total_train_step = 0
# 记录测试的次数
total_test_step = 0
# 训练的轮数
epoch = 10# 添加tensorboard
writer = SummaryWriter("../train_logs")for i in range(epoch):print("-------第{}轮训练开始--------".format(i+1))# 训练网络模型,从训练的data中取数据# 训练步骤开始sun.train()  # 该语句并不一定要写for data in train_dataloader:imgs, targets = dataoutputs = sun(imgs)# 将得到的输出与真实的target比较,得到误差loss =loss_fn(outputs, targets)# 优化器优化模型# 进行优化,首先是梯度清零optimizer.zero_grad()# 得到每个节点的梯度loss.backward()# 对其中的参数进行优化optimizer.step()total_test_step = total_test_step + 1if total_test_step % 100 == 0:print("训练次数:{}, Loss: {}".format(total_test_step, loss.item()))writer.add_scalar("train_loss",loss.item(),total_train_step)# 如何知道数据训练好了没有# 利用现有模型进行测试# 在测试数据集上走一遍,以测试数据集的损失,来判定模型训练好了没有# 测试过程中不需要在对模型进行调优# 测试步骤开始sun.eval() # 该语句也不一定要写total_test_loss = 0total_accuracy = 0with torch.no_grad():for data in test_dataloader:imgs, targets = dataoutputs = sun(imgs)loss =loss_fn(outputs, targets)total_test_loss = total_test_loss + loss.item()accuracy = (outputs.argmax(1)==targets).sum()total_accuracy = total_accuracy + accuracyprint("整体测试集上的Loss:{}".format(total_test_loss))print("整体测试集的正确率:{}".format(total_accuracy))writer.add_scalar("test_loss",total_test_loss, total_test_step)writer.add_scalar("test_accuracy",total_accuracy/test_data_size,total_test_step)total_test_step +=1torch.save(sun, "sun_{}.path".format(i))print("模型已保存")writer.close()​​

 

 


http://www.ppmy.cn/embedded/108341.html

相关文章

前端WebSocket客户端实现

// 创建WebSocket连接 var socket new WebSocket(ws://your-spring-boot-server-url/websocket-endpoint);// 连接打开时触发 socket.addEventListener(open, function (event) {socket.send(JSON.stringify({type: JOIN, room: general})); });// 监听从服务器来的消息 socke…

K8S日志收集

本章主要讲解在 Kubernetes 集群中如何通过不同的技术栈收集容器的日志,包括程序直接输出到控制台日志、自定义文件日志等。 一、有哪些日志需要收集 为了更加方便的处理异常,日志的收集与分析极为重要,在学习日志收集之前,需要知…

QT基础 QPropertyAnimation简单学习

目录 1.简单介绍 2.使用步骤 3.部分代码示例 4.多项说明 5.信号反馈 6.自定义属性 1. 定义自定义属性 2. 使用 QPropertyAnimation 动画化自定义属性 3. 连接信号和槽 4.注意事项 7.更多高级示例 1.简单介绍 QPropertyAnimation是Qt中的一个类,用于实现属性…

idea安装并使用maven依赖分析插件:Maven Helper

在 IntelliJ IDEA 中安装并使用 Maven Helper 插件可以帮助你更方便地管理 Maven 项目的依赖,比如查看依赖树、排除冲突依赖等。以下是安装和使用 Maven Helper 插件的步骤: 安装 Maven Helper 插件 打开 IntelliJ IDEA 并进入你的项目。 在 IDE 的右下…

百度飞浆OCR半自动标注软件OCRLabel配置【详细

今天帮标注人员写了一份完整的百度飞浆OCR标注软件的安装配置说明书、以供标注人员使用 包括各种环境安装包一起分享出来【conda\python\label项目包、清华源配置文件、pycharm社区版安装包】 提取码:umys 1、解压并安装tools文件下的miniconda,建议安装在D盘下的…

Win32绕过UAC弹窗获取管理员权限

在早些年写一些桌面软件时,需要管理员权限,但是又不想UAC弹窗,所以一般是直接将UAC的级别拉到最低,或者直接禁用UAC的相关功能。 什么是UAC(User Account Control) 用户帐户控制 (UAC) 是一项 Windows 安全功能,旨在保…

Flink SQL 中常见的数据类型

Flink SQL 中常见的数据类型 目标 通过了解Flink SQL 中常见的数据类型,掌握正确编写Flink SQL 语句背景 Apache Flink 支持多种数据类型,这些数据类型被用于 Flink SQL 表达式、Table API 以及 DataStream API 中。以下是 Flink SQL 中常见的数据类型: 基本数据类型 Boo…

<Rust>egui学习之部件(十一):如何在窗口中添加单选框radiobutton部件?

前言 本专栏是关于Rust的GUI库egui的部件讲解及应用实例分析,主要讲解egui的源代码、部件属性、如何应用。 环境配置 系统:windows 平台:visual studio code 语言:rust 库:egui、eframe 概述 本文是本专栏的第十一篇…

算法练习题17——leetcode54螺旋矩阵

题目描述 给你一个 m 行 n 列的矩阵 matrix &#xff0c;请按照 顺时针螺旋顺序 &#xff0c;返回矩阵中的所有元素。 代码 import java.util.*;class Solution {public List<Integer> spiralOrder(int[][] matrix) {// 用于存储螺旋顺序遍历的结果List<Integer>…

学习整理使用jquery实现获取相同name被选中的多选框值的方法

学习整理使用jquery实现获取相同name被选中的多选框值的方法 <html><head><meta charset"gbk"><!-- 引入JQuery --><script src"https://www.qipa250.com/jquery/dist/jquery.min.js" type"text/javascript"><…

flutter开发实战-flutter build web微信无法识别二维码及小程序码问题

flutter开发实战-flutter build web微信无法识别二维码及小程序码问题 GitHub Pages是一个直接从GitHub存储库托管的静态站点服务&#xff0c;‌它允许用户通过简单的配置&#xff0c;‌将个人的代码项目转化为一个可以在线访问的网站。‌这里使用flutter build web来构建web发…

现代计算机中数字的表示与浮点数、定点数

现代计算机中数字的表示与浮点数、定点数 导读&#xff1a;浮点数运算是一个非常有技术含量的话题&#xff0c;不太容易掌握。许多程序员都不清楚使用操作符比较float/double类型的话到底出现什么问题。这篇文章讲述了浮点数的来龙去脉&#xff0c;所有的软件开发人员都应该读…

【网络安全】服务基础第二阶段——第三节:Linux系统管理基础----Linux用户与组管理

目录 一、用户与组管理命令 1.1 用户分类与UID范围 1.2 用户管理命令 1.2.1 useradd 1.2.2 groupadd 1.2.3 usermod 1.2.4 userdel 1.3 组管理命令 1.3.1 groupdel 1.3.2 查看密码文件 /etc/shadow 1.3.4 passwd 1.4 Linux密码暴力破解 二、权限管理 2.1 文件与目…

封装触底加载组件

&#xff08;1&#xff09;首先创建一个文件名为&#xff1a;InfiniteScroll.vue <template><div ref"scrollContainer" class"infinite-scroll-container"><slot></slot><div v-if"loading" class"loading-sp…

nginx 新建一个 PC web 站点

注意&#xff1a;进行实例之前必须完成nginx的源码编译。&#xff08;阅读往期文章完成步骤&#xff09; 1.编辑nginx的配置文件&#xff0c;修改内容 [rootlocalhost ~]# vim /usr/local/nginx/conf/nginx.conf 2.创建新目录/usr/local/nginx/conf.d/&#xff0c;编辑新文件…

【激活函数总结】Pytorch中的激活函数详解: ReLU、Leaky ReLU、Sigmoid、Tanh 以及 Softmax

《博主简介》 小伙伴们好&#xff0c;我是阿旭。专注于人工智能、AIGC、python、计算机视觉相关分享研究。 &#x1f44d;感谢小伙伴们点赞、关注&#xff01; 《------往期经典推荐------》 一、AI应用软件开发实战专栏【链接】 项目名称项目名称1.【人脸识别与管理系统开发…

问:你知道IO和NIO有哪些区别不?

一、先表示一下_ Java IOJava NIO主要特点面向流&#xff08;Stream&#xff09;的I/O操作面向缓冲区&#xff08;Buffer&#xff09;和通道&#xff08;Channel&#xff09;的I/O操作&#xff0c;支持非阻塞I/O和选择器&#xff08;Selector&#xff09;常用方法InputStream、…

el-table 单元格,双击编辑

el-table 单元格&#xff0c;双击编辑 实现效果 代码如下 <template><el-table :data"tableData" style"width: 100%"><el-table-column prop"name" label"姓名" width"180"><template slot-scope&q…

flutter开发实战-flutter build web发布到github page及图片未显示问题

flutter开发实战-flutter build web发布到github page及图片未显示问题 GitHub Pages是一个直接从GitHub存储库托管的静态站点服务&#xff0c;‌它允许用户通过简单的配置&#xff0c;‌将个人的代码项目转化为一个可以在线访问的网站。‌这里使用flutter build web来构建web…

反向沙箱-安全上网解决方案

随着信息化的发展&#xff0c;企业日常办公越来越依赖互联网。终端以及普通PC终端在访问互联网过程中&#xff0c;会遇到各种各样不容忽视的风险&#xff0c;例如员工主动故意的数据泄漏&#xff0c;后台应用程序偷偷向外部发信息&#xff0c;木马间谍软件的外联&#xff0c;以…