pytorch-MNIST测试实战

embedded/2024/11/28 0:35:53/

这里写目录标题

  • 1. 为什么test
  • 2. 如何做test
  • 3. 什么时候做test
  • 4. 完整代码

1. 为什么test

如下图:上下两幅图中蓝色分别表示train的accuracy和loss,黄色表示test的accuracy和loss,如果单纯看train的accuracy和loss曲线就会认为模型已经train的很好了,accuracy一直在上升接近于1了,loss一直在下降已经接近于0了,殊不知此时可能已经出现了over fitting(本数据集准确率很高,其他数据准确率很低),此时就需要test了,从图中可以看出test在红色划线右侧的accuracy已经不变甚至下降了,loss曲线波动也比较大,甚至已经上升了。
在这里插入图片描述

2. 如何做test

如下图所示:
argmax找出概率最大的数字的index
softmax在这里使用与不使用结果是一样的,因为softmax不改变单调性(大的依然大,小的依然小)
使用torch.eq计算预测值与目标值是否相当,相等返回1不等返回0
correct.sum().float().item() /4是用来计算accuracy的,其他sum()是计算正确的个数,item是tensor转bumpy; /4是除以总样本数
在这里插入图片描述

3. 什么时候做test

  • 每几个batch做一次
  • 一个epoch做一次
    注意:为什么不一个batch做一次test呢?因为test的数据可能也比较大,每个batch都test会影响train的速度

4. 完整代码

从一下代码可知,test是一个epoch做一次,首先像train一样load test数据,并搬到GPU中,然后数据输入到网络中,计算loss,最后计算准确了并打印输出

python">import  torch
import  torch.nn as nn
import  torch.nn.functional as F
import  torch.optim as optim
from    torchvision import datasets, transformsbatch_size=200
learning_rate=0.01
epochs=10train_loader = torch.utils.data.DataLoader(datasets.MNIST('../data', train=True, download=True,transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])),batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(datasets.MNIST('../data', train=False, transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])),batch_size=batch_size, shuffle=True)class MLP(nn.Module):def __init__(self):super(MLP, self).__init__()self.model = nn.Sequential(nn.Linear(784, 200),nn.LeakyReLU(inplace=True),nn.Linear(200, 200),nn.LeakyReLU(inplace=True),nn.Linear(200, 10),nn.LeakyReLU(inplace=True),)def forward(self, x):x = self.model(x)return xdevice = torch.device('cuda:0')
net = MLP().to(device)
optimizer = optim.SGD(net.parameters(), lr=learning_rate)
criteon = nn.CrossEntropyLoss().to(device)for epoch in range(epochs):for batch_idx, (data, target) in enumerate(train_loader):data = data.view(-1, 28*28)data, target = data.to(device), target.cuda()logits = net(data)loss = criteon(logits, target)optimizer.zero_grad()loss.backward()# print(w1.grad.norm(), w2.grad.norm())optimizer.step()if batch_idx % 100 == 0:print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch, batch_idx * len(data), len(train_loader.dataset),100. * batch_idx / len(train_loader), loss.item()))test_loss = 0correct = 0for data, target in test_loader:data = data.view(-1, 28 * 28)data, target = data.to(device), target.cuda()logits = net(data)test_loss += criteon(logits, target).item()pred = logits.argmax(dim=1)correct += pred.eq(target).float().sum().item()test_loss /= len(test_loader.dataset)print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(test_loss, correct, len(test_loader.dataset),100. * correct / len(test_loader.dataset)))

http://www.ppmy.cn/embedded/5276.html

相关文章

mac修改/etc/profile导致终端所有命令不可使用

原因:配置docker环境的时候修改了/etc/profile,没想到导致悲惨事情,输入什么命令都是 Command not found 可恶!!!试了好久,最终这样搞定! 1-终端输入命令 因为sudo命令也不能直接…

【Stable Diffusion】ModuleNotFoundError: No module named ‘ifnude‘ and roop v0.0.2

提示:ModuleNotFoundError: No module named ‘ifnude’ 一、issues/299:ModuleNotFoundError: No module named ‘ifnude’ 路径 cmd 中也可以看到,路径可能有点不一样,但是后面的路径应该都是一样的,如:…

I2C,UART,SPI(STM32、51单片机)

目录 基本理论知识: 并行通信/串行通信: 异步通信/同步通信: 半双工通信/全双工通信: UART串口: I2C串口: SPI串口: I2C在单片机中的应用: 软件模拟: 51单片机:…

AI降维算法

降维算法主要分为线性降维和非线性降维两种。 线性降维方法中&#xff0c;主成分分析&#xff08;PCA&#xff09;是最基础的无监督降维算法&#xff0c;其目标是将原有的n个特征投影到k维空间&#xff08;k<n&#xff09;&#xff0c;新的特征由原特征线性变换而来&#x…

Matlab之过球面一点的平面方程

这篇文章描述2件事情&#xff1a; 1、已知球面上任意点&#xff0c;求过该点、地心、与北极点的平面方程&#xff08;即过该点的经线平面方程&#xff09;&#xff1b; 2、绕过球心的任意轴旋转平面得到新平面的方程 一、已知球面上任意点&#xff0c;求过该点、地心、与北极点…

深度学习基础——卷积神经网络的感受野、参数量、计算量

深度学习基础——卷积神经网络的感受野、参数量、计算量 深度学习在图像处理领域取得了巨大的成功&#xff0c;其中卷积神经网络&#xff08;Convolutional Neural Networks&#xff0c;CNN&#xff09;是一种非常重要的网络结构。本文将介绍卷积神经网络的三个重要指标&#…

基于CppHttpLib的Httpserver

1 背景 大多数嵌入式设备由于没有屏幕输出&#xff0c;只能通过Web页面来配置。这里利用CPPHttpLib来实现HttpServer。 2 HttpServer HttpServer是利用CPPHttpLib开源库实现的Http服务器CppHttpLib是基于C11的HTTP开源库&#xff0c;开源协议是MIT. CppHttpLib下载地址 2.1 …

【6】mysql查询性能优化-关联子查询

【README】 0. 先说结论&#xff1a;一般用inner join来改写in和exist&#xff0c;用left join来改写not in&#xff0c;not exist&#xff1b;&#xff08;本文会比较内连接&#xff0c;包含in子句的子查询&#xff0c;exist的性能 &#xff09; 1. 本文总结自高性能mysql 6…