基于pytorch的手写数字识别-训练+使用

news/2024/10/7 21:39:23/
python">import pandas as pd
import numpy as np
import torch
import matplotlib
import matplotlib.pyplot as plt
from torch.utils.data import TensorDataset, DataLoadermatplotlib.use('tkAgg')# 设置图形配置
config = {"font.family": 'serif',"mathtext.fontset": 'stix',"font.serif": ['SimSun'],'axes.unicode_minus': False
}
matplotlib.rcParams.update(config)def mymap(labels):return np.where(labels < 10, labels, 0)# 数据加载
path = "d:\\JD\\Documents\\大学等等等\\自学部分\\机器学习自学画图\\手写数字识别\\ex3data1.xlsx"
data = pd.read_excel(path)
data = np.array(data, dtype=np.float32)
x = data[:, :-1]
labels = data[:, -1]
labels = mymap(labels)# 转换为Tensor
x = torch.tensor(x, dtype=torch.float32)
labels = torch.tensor(labels, dtype=torch.long)# 创建Dataset和Dataloader
dataset = TensorDataset(x, labels)
train_loader = DataLoader(dataset, batch_size=20, shuffle=True)# 检查GPU是否可用
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 定义模型
my_nn = torch.nn.Sequential(torch.nn.Linear(400, 128),torch.nn.Sigmoid(),torch.nn.Linear(128, 256),torch.nn.Sigmoid(),torch.nn.Linear(256, 512),torch.nn.Sigmoid(),torch.nn.Linear(512, 10)
).to(device)# 加载预训练模型
my_nn.load_state_dict(torch.load('model.pth'))
my_nn.eval()  # 切换至评估模式# 准备选取数据进行预测
sample_indices = np.random.choice(len(dataset), 50, replace=False)  # 随机选择50个样本
sample_images = x[sample_indices].to(device)  # 选择样本并移动到GPU
sample_labels = labels[sample_indices].numpy()  # 真实标签# 进行预测
with torch.no_grad():  # 禁用梯度计算predictions = my_nn(sample_images)predicted_labels = torch.argmax(predictions, dim=1).cpu().numpy()  # 获取预测的标签# 绘制图像
plt.figure(figsize=(10, 10))
for i in range(50):plt.subplot(10, 5, i + 1)  # 10行5列的子图plt.imshow(sample_images[i].cpu().reshape(20, 20), cmap='gray')  # 还原为20x20图像plt.title(f'Predicted: {predicted_labels[i]}', fontsize=8)plt.axis('off')  # 关闭坐标轴plt.tight_layout()  # 调整子图间距
plt.show()

Iteration 0, Loss: 0.8472495079040527
Iteration 20, Loss: 0.014742681756615639
Iteration 40, Loss: 0.00011596851982176304
Iteration 60, Loss: 9.278443030780181e-05
Iteration 80, Loss: 1.3701709576707799e-05
Iteration 100, Loss: 5.019319928578625e-07
Iteration 120, Loss: 0.0
Iteration 140, Loss: 0.0
Iteration 160, Loss: 1.2548344585638915e-08
Iteration 180, Loss: 1.700657230685465e-05
预测准确率: 100.00%

下面使用已经训练好的模型,进行再次测试:

python">import pandas as pd
import numpy as np
import torch
import matplotlib
import matplotlib.pyplot as plt
from torch.utils.data import TensorDataset, DataLoadermatplotlib.use('tkAgg')# 设置图形配置
config = {"font.family": 'serif',"mathtext.fontset": 'stix',"font.serif": ['SimSun'],'axes.unicode_minus': False
}
matplotlib.rcParams.update(config)def mymap(labels):return np.where(labels < 10, labels, 0)# 数据加载
path = "d:\\JD\\Documents\\大学等等等\\自学部分\\机器学习自学画图\\手写数字识别\\ex3data1.xlsx"
data = pd.read_excel(path)
data = np.array(data, dtype=np.float32)
x = data[:, :-1]
labels = data[:, -1]
labels = mymap(labels)# 转换为Tensor
x = torch.tensor(x, dtype=torch.float32)
labels = torch.tensor(labels, dtype=torch.long)# 创建Dataset和Dataloader
dataset = TensorDataset(x, labels)
train_loader = DataLoader(dataset, batch_size=20, shuffle=True)# 检查GPU是否可用
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 定义模型
my_nn = torch.nn.Sequential(torch.nn.Linear(400, 128),torch.nn.Sigmoid(),torch.nn.Linear(128, 256),torch.nn.Sigmoid(),torch.nn.Linear(256, 512),torch.nn.Sigmoid(),torch.nn.Linear(512, 10)
).to(device)# 加载预训练模型
my_nn.load_state_dict(torch.load('model.pth'))
my_nn.eval()  # 切换至评估模式# 准备选取数据进行预测
sample_indices = np.random.choice(len(dataset), 50, replace=False)  # 随机选择50个样本
sample_images = x[sample_indices].to(device)  # 选择样本并移动到GPU
sample_labels = labels[sample_indices].numpy()  # 真实标签# 进行预测
with torch.no_grad():  # 禁用梯度计算predictions = my_nn(sample_images)predicted_labels = torch.argmax(predictions, dim=1).cpu().numpy()  # 获取预测的标签plt.figure(figsize=(16, 10))
for i in range(20):plt.subplot(4, 5, i + 1)  # 4行5列的子图plt.imshow(sample_images[i].cpu().reshape(20, 20), cmap='gray')  # 还原为20x20图像plt.title(f'True: {sample_labels[i]}, Pred: {predicted_labels[i]}', fontsize=12)  # 标题中显示真实值和预测值plt.axis('off')  # 关闭坐标轴plt.tight_layout()  # 调整子图间距
plt.show()


http://www.ppmy.cn/news/1535874.html

相关文章

性能学习5:性能测试的流程

一.需求分析 二.性能测试计划 1&#xff09;测什么&#xff1f; - 项目背景 - 测试目的 - 测试范围 - ... 2&#xff09;谁来测试 - 时间进度与分工 - 交付清单 - ... 3&#xff09;怎么测 - 测试策略 - ... 三.性能测试用例 四.性能测试执行 五.性能分析和调优 六…

Ultralytics_yolov10目标检测,预处理函数入口

日期&#xff1a;2024.10.7. 随着Ultralytics的更新&#xff0c;yolov5-v11可以统一使用Ultralytics包体&#xff0c;我之前分析的yolov5关键代码定位在Ultralytics中不适用&#xff0c;这篇博客更新一下。 1. Ultralytics包体版本&#xff1a; $ pip list | grep ultralytic…

Redis --- 第二讲 --- 特性和安装

一、背景知识 Redis特性&#xff1a; Redis是一个在内存中存储数据的中间件&#xff0c;用于作为数据库&#xff0c;作为缓存&#xff0c;在分布式系统中能够大展拳脚。Redis的一些特性造就了现在的Redis。 在内存中存储数据&#xff0c;通过一系列的数据结构。MySQL主要是通…

yield:生成器 ----------------

yield&#xff1a;生成器 任何使用yield的函数都称之为生成器&#xff0c;如&#xff1a; def count(n):while n > 0:yield n #生成值&#xff1a;nn - 1另外一种说法&#xff1a;生成器就是一个返回迭代器的函数&#xff0c;与普通函数的区别是生成器包含yield语句&…

23.2 prometheus为k8s做的4大适配工作

本节重点介绍 : k8s监控中的4大采集类型总结prometheus为k8s监控做的4大适配工作 k8s关注指标分析 在监控每个细分的领域时&#xff0c;我们都要先思考下到底需要关注哪些方面的指标。k8s中组件复杂&#xff0c;我们主要专注的无外乎四大块指标&#xff1a;容器基础资源指标…

AR技术在电商行业的应用及优势有哪些?

AR&#xff08;增强现实&#xff09;技术在电商行业的应用广泛且深入&#xff0c;为消费者带来了全新的购物体验&#xff0c;同时也为商家带来了诸多优势。以下是AR技术在电商行业的主要应用场景及其优势&#xff1a; 一、应用场景 1、虚拟商品展示与试用 家具AR摆放&#x…

B 私域模式升级:开源技术助力传统经销体系转型

一、引言 1.1 研究背景 随着市场竞争加剧&#xff0c;传统经销代理体系面临挑战。同时&#xff0c;开源技术发展迅速&#xff0c;为 B 私域升级带来新机遇。在当今数字化时代&#xff0c;企业面临着日益激烈的市场竞争。传统的经销代理体系由于管理效率低下、渠道局限、库存压…

eNSP网络配置指南:IP设置、DNS、Telnet、DHCP与路由表管理

1.eNSP基本操作和路由器IP配置命令 登录设备&#xff1a;通过Console口或通过eNSP的Telnet/SSH客户端登录到设备。进入特权模式&#xff1a;输入system-view进入系统视图。接口配置&#xff1a; 进入接口视图&#xff0c;例如interface GigabitEthernet0/0/0。配置IP地址和子网…