【python】pytorch包(第四章)手写数字图像识别

news/2024/11/30 2:35:52/

问题描述:

给定手写字体的图片,人工智能自动判断这是数字几

数据来源:

MNIST数据集

代码实战:

Part 1. 准备数据集

该模块内容完成的功能:

  1. 下载MNIST数据集;
  2. 转换数据格式,使适用于pytorch;
  3. 数据分批;
  4. 将上述功能 API化
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import Compose,ToTensor,Normalize
def get_data(Batch_Size,train=True):#train = True则是训练集,否则是测试集#以每批Batch_size大小进行数据分批transform_fn = Compose([ToTensor(), #转张量Normalize(mean=(0.1307,),std=(0.3081))#正则化#mean和std的形状和通道数相同])dataset = MNIST( #数据集类别root=r'E:\MNIST数字识别\training', #将数据集存储在路径/files/内train =True, #True表示获取的是训练s集,否则获取的是训练集download=False,#如果没有下载过数据集,则需要标True下载transform= transform_fn #图片处理函数)data_loader = DataLoader( #分批次的数据集类别dataset, #将dataset分批batch_size = Batch_Size, #每批包含Batch_size个数据shuffle=True #随机打乱)return data_loader

Part 2. 构建模型

1. 构建模型逻辑

神经网络结构:
四层神经网络,输入层->全连接层1->全连接层2->输出层

参数设定:

激活函数: relu()

损失函数: 交叉熵损失函数

数据形状:

原始数据:[batch_size,1,28,28] # 原始数据形状

输入层:[batch_size,1x28x28] # 摊开

第一层输出:[batch_size,28] # 参数28可以自行修改

第二层输出:[batch_size,10] # 十个数字,十个类别

优化器: Adam()

import torch
from torch import nn
import torch.nn.functional as F
class Number_Identify(nn.Module):def __init__(self):super(Number_Identify,self).__init__() #继承父类init的参数self.fc1 = nn.Linear(1*28*28,28,bias=True) #第一层神经网络,输入维度为1*28*28,输出维度为28self.fc2 = nn.Linear(28,10,bias=True)#第二层神经网络,输入维度为28,输出维度为10def forward(self,x): #模型输入x,输出outx = x.view(-1,1*28*28) #view()相当于reshape(),参数为-1表示 根据情况自适应调整x = self.fc1(x) #经过第一层神经网络计算x = F.relu(x) #经过激活函数out = self.fc2(x) #经过第二层神经网络计算return F.log_softmax(out,dim=-1)

2. 模型实例化

from torch import optim
model = Number_Identify()#调用模型基底、
#criterion = nn.CrossEntropyLoss() #损失函数
optimizer = optim.Adam(model.parameters(),lr=1e-3) #优化器

Part 3. 训练模型

该模块内容完成的功能:

  1. 从MNIST数据集导入训练数据集
  2. 实现训练逻辑
  3. 将上述功能 API化

构建模型训练的逻辑过程

#训练函数
def train(epoch):model.train(mode=True) #当前模型设定为训练模式data_train = get_data(2,train=True) #获取训练数据,数据按每两个一组分批for idx,(data,target) in enumerate(data_train):optimizer.zero_grad() #清零梯度out = model(data) #向前计算:预测当前数据的结果loss = F.nll_loss(out,target) #计算带权损失loss.backward() #反向传播optimizer.step() #参数更新#训练进度展示:if idx%10000 == 0:print('\t batches[%d/%d],loss:%.6f' % (idx,len(data_train),loss.data)) 

模型训练

#训练
training_times = int(input("输入训练次数:"))
for epoch in range(training_times): #多次训练print("Train_epoch[%d/%d]:" % (epoch+1,training_times))train(epoch)print("=========Train_epoch[%d/%d] finished======" % (epoch+1,training_times))
print("======训练完成======")

Part 4. 保存模型

torch.save(model.state_dict(),r'E:\AI_Model_save\Number_Identify\model_net.pt'
)#保存模型
torch.save(optimizer.state_dict(),r'E:\AI_Model_save\Number_Identify\model_optimiter.pt'
)#保存优化器

Part 5. 模型使用

1. 加载模型

import os
import torch
if os.path.exists(r'E:\AI_Model_save\Number_Identify'):       model.load_state_dict(torch.load(r'E:\AI_Model_save\Number_Identify\model_net.pt')) #加载模型optimizer.load_state_dict(torch.load(r'E:\AI_Model_save\Number_Identify\model_optimiter.pt'))#加载优化器print("======成功调用=======")
else: print("路径错误")

2. 模型评估

编写test函数,实现模型评估的API
实现功能:批量预测+计算正确率
API部分

import numpy as np
import torch
def test(model,data_test):loss_list = []accuracy_list = []for idx,(Input,target) in enumerate(data_test):with torch.no_grad():#预测状态,不改变梯度参数output = model(Input) #批量预测cur_loss = F.nll_loss(output,target) #计算损失loss_list.append(cur_loss)pred = output.max(dim=-1)[-1] #批量预测结果accuracy = pred.eq(target).float().mean() #计算准确率accuracy_list.append(accuracy)#这里计算的是每个batch的正确率与损失return np.mean(accuracy_list),np.mean(loss_list)

导入测试数据 并测试

data_test = get_data(100,train=False) 
#读取测试数据集,每100个为一组进行测试
accuracy,loss = test(model,data_test) #测试数据
print("准确率:%.2f" % (accuracy*100),"%")
print("Loss:",loss)

3. 单图预测【待填】

After all:

用pytorch做了这个实战,个人感受是:
pytorch的数据预处理更加傻瓜式,但伴随的参数也更多,需要熟悉的API也更多,学起来更麻烦一些,成本更高,但使用起来的便利性更好;
相比之下,keras的数据预处理需要我们自己完成,学起来很简单,但用起来很麻烦(每次都要自己手写数据的预处理)


http://www.ppmy.cn/news/95919.html

相关文章

算法6.堆结构、堆排序、加强堆

算法|6.堆结构、堆排序、加强堆 1.比较器的定义 题意:定义一个学生类,分别实现对学生对象数组按年龄升序、按id降序、按名字的字典序、按id排序且id相同的年龄大的排在前边。 解题思路: 定义一个学生类定义一个实现了Comparator接口的类A…

C++ 学习 ::【基础篇:06】:C++ (inline)内联函数的介绍及其出现的意义【对比于 C语言宏函数】

本系列 C 相关文章 仅为笔者学习笔记记录,用自己的理解记录学习!C 学习系列将分为三个阶段:基础篇、STL 篇、高阶数据结构与算法篇,相关重点内容如下: 基础篇:类与对象(涉及C的三大特性等&#…

Java中的String数据类型,String类(字符串)详解

目录 第一章、String概述1)String是什么2)String长什么样3)String的构造方法(声明方式) 第二章、String类的详解1)String底层是什么2)字符串存储的内存原理/字符串常量池(String Constant Pool)3&#xff0…

1. 从JDK源码级别彻底刨析JVM类加载机制

JVM性能调优 1. 类加载的运行全过程1.1 加载1.2 验证1.3 准备1.4 解析 本文是按照自己的理解进行笔记总结,如有不正确的地方,还望大佬多多指点纠正,勿喷。 课程内容: 1、从java.exe开始讲透Java类加载运行全过程 2、从JDK源码级别剖析JVM核…

科技的力量:致敬全国科技工作者

在这个日新月异的时代,科技的力量正在改变着我们的生活。为了庆祝5月30日的全国科技者工作日,我们特地上线本次创作活动,向所有为科技进步付出辛勤努力的科技工作者们致敬。在这篇博文中,我们将通过讲述科技发展的故事、分享技术成…

洛谷01背包变形(P1858多人背包)

多人背包 文章目录 一、问题简述二、问题分析三、代码参考 一、问题简述 DD 和好朋友们要去爬山啦! 他们一共有 K 个人,每个人都会背一个包。这些包 的容量是相同的,都是 V。可以装进背包里的一共有 N 种物品,每种物品都有 给定…

【地铁上的面试题】--基础部分--数据结构与算法--数组和链表

零、章节简介 《数据结构与算法》是《地铁上的面试题》专栏的第一章,重点介绍了技术面试中不可或缺的数据结构和算法知识。数据结构是组织和存储数据的方式,而算法是解决问题的步骤和规则。 这一章的内容涵盖了常见的数据结构和算法,包括数组…

NodeJs服务链路追踪日志

(逆境给人宝贵的磨炼机会。仅有经得起环境考验的人,才能算是真正的强者。自古以来的伟人,大多是抱着不屈不挠的精神,从逆境中挣扎奋斗过来的。——松下幸之助) 服务链路追踪 服务的链路追踪指我们可以通过一个标记&am…