基于MNIST数据集完成MindSpore模型训练的入门

server/2025/3/31 12:09:54/

1、下载并处理数据集
数据集对于模型训练非常重要,好的数据集可以有效提高训练精度和效率。示例中用到的MNIST数据集是由10类28∗28的灰度图片组成,训练数据集包含60000张图片,测试数据集包含10000张图片。
在这里插入图片描述

MindSpore Vision套件提供了用于下载并处理MNIST数据集的Mnist模块,以下示例代码将数据集下载、解压到指定位置并进行数据处理:
本章节中的示例代码
依赖mindvision,可使用命令pip install mindvision安装。如本文档以Notebook运行时,完成安装后需要重启kernel才能执行后续代码。
pip install mindvision
在这里插入图片描述

from mindvision.dataset import Mnist

下载并处理MNIST数据集

download_train = Mnist(path=“/home/ma-user/minist/MNIST_data”, split=“train”, batch_size=32, repeat_num=1, shuffle=True, resize=32, download=True)

download_eval = Mnist(path=“/home/ma-user/minist/MNIST_data”, split=“test”, batch_size=32, resize=32, download=True)

dataset_train = download_train.run()
dataset_eval = download_eval.run()

参数说明:
● path:数据集路径。
● split:数据集类型,支持train、 test、infer,默认为train。
● batch_size:每个训练批次设定的数据大小,默认为32。
● repeat_num:训练时遍历数据集的次数,默认为1。
● shuffle:是否需要将数据集随机打乱(可选参数)。
● resize:输出图像的图像大小,默认为32*32。
● download:是否需要下载数据集,默认为False。
下载的数据集文件的目录结构如下:
./mnist/
├── test
│ ├── t10k-images-idx3-ubyte
│ └── t10k-labels-idx1-ubyte
└── train
├── train-images-idx3-ubyte
└── train-labels-idx1-ubyte

2、创建模型
按照LeNet的网络结构,LeNet除去输入层共有7层,其中有2个卷积层,2个子采样层,3个全连接层。
在这里插入图片描述

MindSpore Vision套件提供了LeNet网络模型接口lenet, 定义网络模型如下:
from mindvision.classification.models import lenet

network = lenet(num_classes=10, pretrained=False)
3、定义损失函数和优化器
要训练神经网络模型,需要定义损失函数和优化器函数。
● 损失函数这里使用交叉熵损失函数SoftmaxCrossEntropyWithLogits。
● 优化器这里使用Momentum。
import mindspore.nn as nn

定义损失函数

net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction=‘mean’)

定义优化器函数

net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.01, momentum=0.9)

Epoch:[ 0/ 10], step:[ 1875/ 1875], loss:[0.025/0.358], time:31.886 ms, lr:0.01000
Epoch time: 46736.737 ms, per step time: 24.926 ms, avg loss: 0.358
Epoch:[ 1/ 10], step:[ 1875/ 1875], loss:[0.022/0.061], time:29.132 ms, lr:0.01000
Epoch time: 23467.796 ms, per step time: 12.516 ms, avg loss: 0.061
Epoch:[ 2/ 10], step:[ 1875/ 1875], loss:[0.103/0.044], time:25.255 ms, lr:0.01000
Epoch time: 23860.372 ms, per step time: 12.726 ms, avg loss: 0.044
Epoch:[ 3/ 10], step:[ 1875/ 1875], loss:[0.082/0.035], time:20.321 ms, lr:0.01000
Epoch time: 24797.029 ms, per step time: 13.225 ms, avg loss: 0.035
Epoch:[ 4/ 10], step:[ 1875/ 1875], loss:[0.165/0.028], time:29.398 ms, lr:0.01000
Epoch time: 25361.749 ms, per step time: 13.526 ms, avg loss: 0.028
Epoch:[ 5/ 10], step:[ 1875/ 1875], loss:[0.014/0.024], time:23.179 ms, lr:0.01000
Epoch time: 25456.395 ms, per step time: 13.577 ms, avg loss: 0.024
Epoch:[ 6/ 10], step:[ 1875/ 1875], loss:[0.164/0.020], time:26.164 ms, lr:0.01000
Epoch time: 23664.757 ms, per step time: 12.621 ms, avg loss: 0.020
Epoch:[ 7/ 10], step:[ 1875/ 1875], loss:[0.001/0.018], time:23.159 ms, lr:0.01000
Epoch time: 24761.985 ms, per step time: 13.206 ms, avg loss: 0.018
Epoch:[ 8/ 10], step:[ 1875/ 1875], loss:[0.000/0.016], time:21.047 ms, lr:0.01000
Epoch time: 27926.245 ms, per step time: 14.894 ms, avg loss: 0.016
训练过程中会打印loss值,loss值会波动,但总体来说loss值会逐步减小,精度逐步提高。每个人运行的loss值有一定随机性,不一定完全相同。
通过模型运行测试数据集得到的结果,验证模型的泛化能力:

  1. 使用model.eval接口读入测试数据集。
  2. 使用保存后的模型参数进行推理。
    acc = model.eval(dataset_eval)

print(“{}”.format(acc))
在这里插入图片描述

可以在打印信息中看出模型精度数据,示例中精度数据达到95%以上,模型质量良好。随着网络迭代次数增加,模型精度会进一步提高。
4、加载模型
from mindspore import load_checkpoint, load_param_into_net

#加载已经保存的用于测试的模型
param_dict = load_checkpoint(“./lenet/lenet-1_1875.ckpt”)
加载参数到网络中
load_param_into_net(network, param_dict)
在这里插入图片描述

验证模型

我们使用生成的模型进行单个图片数据的分类预测,具体步骤如下
● 被预测的图片会随机生成,每次运行结果可能会不一样。
● 代码使用了Tensor模块,有关张量Tensor的信息。
import numpy as np
from mindspore import Tensor
import matplotlib.pyplot as plt

mnist = Mnist(“/home/ma-user/minist/MNIST_data”, split=“train”, batch_size=6, resize=32)
dataset_infer = mnist.run()
ds_test = dataset_infer.create_dict_iterator()
data = next(ds_test)
images = data[“image”].asnumpy()
labels = data[“label”].asnumpy()

plt.figure()
for i in range(1, 7):
plt.subplot(2, 3, i)
plt.imshow(images[i-1][0], interpolation=“None”, cmap=“gray”)
plt.show()

#使用函数model.predict预测image对应分类
output = model.predict(Tensor(data[‘image’]))
predicted = np.argmax(output.asnumpy(), axis=1)

#输出预测分类与实际分类
print(f’Predicted: “{predicted}”, Actual: “{labels}”')
在这里插入图片描述

从上面的打印结果可以看出,预测值与目标值完全一致。


http://www.ppmy.cn/server/176852.html

相关文章

Redis 本地安装

首先安装: https://redis.io/docs/latest/operate/oss_and_stack/install/install-redis/install-redis-from-source/ 进入root目录 tar -xzvf redis-stable.tar.gz cd redis-stable make然后 install sudo make install最后可以直接启动 redis-server但是此时启…

【虚幻引擎UE5】SpawnActor生成Character实例不执行AI Move To,未初始化AIController的原因和解决方法

虚幻引擎版本:5.5.4 问题描述 刚创建的Third Person项目里,定义一个BP_Enemy蓝图,拖拽到场景中产生的实例会追随玩家,但SpawnActor产生的实例会固定不动。BP_Enemy蓝图具体设计如下: BP_Enemy的Event Graph ​​ 又定义…

【redis】在 Spring中操作 Redis

文章目录 基础设置依赖StringRedisTemplate库的封装 运行StringList删库 SetHashZset 基础设置 依赖 需要选择这个依赖 StringRedisTemplate // 后续 redis 测试的各种方法,都通过这个 Controller 提供的 http 接口来触发 RestController public class MyC…

[贪心算法]买卖股票的最佳时机 买卖股票的最佳时机Ⅱ K次取反后最大化的数组和 按身高排序 优势洗牌(田忌赛马)

1.买卖股票的最佳时机 暴力解法就是两层循环&#xff0c;找出两个差值最大的即可。 优化&#xff1a;在找最小的时候不用每次都循环一遍&#xff0c;只要在i向后走的时候&#xff0c;每次记录一下最小的值即可 class Solution { public:int maxProfit(vector<int>& p…

C++11QT复习

文章目录 QT C 培训Day1 环境安装和入门&#xff08;2025.03.05&#xff09;Qt 自带的编译器Qt 的编译脚本&#xff1a;qmake / CMake**示例&#xff1a;Test.pro 文件** Qt 的版本控制系统C 中的头文件C 中的命名空间C 中的编译、链接、运行 Day2 C语法和工程实践&#xff08;…

CAM350-14.6学习笔记-1:导入Gerber文件

CAM350-14.6学习笔记-1:导入Gerber文件 使用自动导入器导入Gerber1&#xff1a;导航栏Home下面的Import——Automatic Import——选择文件路径——Next2&#xff1a;设置每层的类型&#xff1a;3&#xff1a;设置叠层4&#xff1a;弹出层别显示框及Gerber显示 按照Allegro输出的…

【python】OpenCV—Template Matching

文章目录 1、功能描述2、原理分析3、代码实现4、效果展示5、完整代码6、涉及到的库函数7、参考 更多有趣的代码示例&#xff0c;可参考【Programming】 1、功能描述 基于 opencv-python 实现模板匹配算法 2、原理分析 算法流程 &#xff08;1&#xff09;滑动窗口 将模板图…

Moonlight-16B-A3B: 变革性的高效大语言模型,凭借Muon优化器打破训练效率极限

近日&#xff0c;由Moonshot AI团队推出的Moonlight-16B-A3B模型&#xff0c;再次在AI领域引发了广泛关注。这款全新的Mixture-of-Experts (MoE)架构的大型语言模型&#xff0c;凭借其创新的训练优化技术&#xff0c;特别是Muon优化器的使用&#xff0c;成功突破了训练效率的极…