1、下载并处理数据集
数据集对于模型训练非常重要,好的数据集可以有效提高训练精度和效率。示例中用到的MNIST数据集是由10类28∗28的灰度图片组成,训练数据集包含60000张图片,测试数据集包含10000张图片。
MindSpore Vision套件提供了用于下载并处理MNIST数据集的Mnist模块,以下示例代码将数据集下载、解压到指定位置并进行数据处理:
本章节中的示例代码
依赖mindvision,可使用命令pip install mindvision安装。如本文档以Notebook运行时,完成安装后需要重启kernel才能执行后续代码。
pip install mindvision
from mindvision.dataset import Mnist
下载并处理MNIST数据集
download_train = Mnist(path=“/home/ma-user/minist/MNIST_data”, split=“train”, batch_size=32, repeat_num=1, shuffle=True, resize=32, download=True)
download_eval = Mnist(path=“/home/ma-user/minist/MNIST_data”, split=“test”, batch_size=32, resize=32, download=True)
dataset_train = download_train.run()
dataset_eval = download_eval.run()
参数说明:
● path:数据集路径。
● split:数据集类型,支持train、 test、infer,默认为train。
● batch_size:每个训练批次设定的数据大小,默认为32。
● repeat_num:训练时遍历数据集的次数,默认为1。
● shuffle:是否需要将数据集随机打乱(可选参数)。
● resize:输出图像的图像大小,默认为32*32。
● download:是否需要下载数据集,默认为False。
下载的数据集文件的目录结构如下:
./mnist/
├── test
│ ├── t10k-images-idx3-ubyte
│ └── t10k-labels-idx1-ubyte
└── train
├── train-images-idx3-ubyte
└── train-labels-idx1-ubyte
2、创建模型
按照LeNet的网络结构,LeNet除去输入层共有7层,其中有2个卷积层,2个子采样层,3个全连接层。
MindSpore Vision套件提供了LeNet网络模型接口lenet, 定义网络模型如下:
from mindvision.classification.models import lenet
network = lenet(num_classes=10, pretrained=False)
3、定义损失函数和优化器
要训练神经网络模型,需要定义损失函数和优化器函数。
● 损失函数这里使用交叉熵损失函数SoftmaxCrossEntropyWithLogits。
● 优化器这里使用Momentum。
import mindspore.nn as nn
定义损失函数
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction=‘mean’)
定义优化器函数
net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.01, momentum=0.9)
Epoch:[ 0/ 10], step:[ 1875/ 1875], loss:[0.025/0.358], time:31.886 ms, lr:0.01000
Epoch time: 46736.737 ms, per step time: 24.926 ms, avg loss: 0.358
Epoch:[ 1/ 10], step:[ 1875/ 1875], loss:[0.022/0.061], time:29.132 ms, lr:0.01000
Epoch time: 23467.796 ms, per step time: 12.516 ms, avg loss: 0.061
Epoch:[ 2/ 10], step:[ 1875/ 1875], loss:[0.103/0.044], time:25.255 ms, lr:0.01000
Epoch time: 23860.372 ms, per step time: 12.726 ms, avg loss: 0.044
Epoch:[ 3/ 10], step:[ 1875/ 1875], loss:[0.082/0.035], time:20.321 ms, lr:0.01000
Epoch time: 24797.029 ms, per step time: 13.225 ms, avg loss: 0.035
Epoch:[ 4/ 10], step:[ 1875/ 1875], loss:[0.165/0.028], time:29.398 ms, lr:0.01000
Epoch time: 25361.749 ms, per step time: 13.526 ms, avg loss: 0.028
Epoch:[ 5/ 10], step:[ 1875/ 1875], loss:[0.014/0.024], time:23.179 ms, lr:0.01000
Epoch time: 25456.395 ms, per step time: 13.577 ms, avg loss: 0.024
Epoch:[ 6/ 10], step:[ 1875/ 1875], loss:[0.164/0.020], time:26.164 ms, lr:0.01000
Epoch time: 23664.757 ms, per step time: 12.621 ms, avg loss: 0.020
Epoch:[ 7/ 10], step:[ 1875/ 1875], loss:[0.001/0.018], time:23.159 ms, lr:0.01000
Epoch time: 24761.985 ms, per step time: 13.206 ms, avg loss: 0.018
Epoch:[ 8/ 10], step:[ 1875/ 1875], loss:[0.000/0.016], time:21.047 ms, lr:0.01000
Epoch time: 27926.245 ms, per step time: 14.894 ms, avg loss: 0.016
训练过程中会打印loss值,loss值会波动,但总体来说loss值会逐步减小,精度逐步提高。每个人运行的loss值有一定随机性,不一定完全相同。
通过模型运行测试数据集得到的结果,验证模型的泛化能力:
- 使用model.eval接口读入测试数据集。
- 使用保存后的模型参数进行推理。
acc = model.eval(dataset_eval)
print(“{}”.format(acc))
可以在打印信息中看出模型精度数据,示例中精度数据达到95%以上,模型质量良好。随着网络迭代次数增加,模型精度会进一步提高。
4、加载模型
from mindspore import load_checkpoint, load_param_into_net
#加载已经保存的用于测试的模型
param_dict = load_checkpoint(“./lenet/lenet-1_1875.ckpt”)
加载参数到网络中
load_param_into_net(network, param_dict)
验证模型
我们使用生成的模型进行单个图片数据的分类预测,具体步骤如下
● 被预测的图片会随机生成,每次运行结果可能会不一样。
● 代码使用了Tensor模块,有关张量Tensor的信息。
import numpy as np
from mindspore import Tensor
import matplotlib.pyplot as plt
mnist = Mnist(“/home/ma-user/minist/MNIST_data”, split=“train”, batch_size=6, resize=32)
dataset_infer = mnist.run()
ds_test = dataset_infer.create_dict_iterator()
data = next(ds_test)
images = data[“image”].asnumpy()
labels = data[“label”].asnumpy()
plt.figure()
for i in range(1, 7):
plt.subplot(2, 3, i)
plt.imshow(images[i-1][0], interpolation=“None”, cmap=“gray”)
plt.show()
#使用函数model.predict预测image对应分类
output = model.predict(Tensor(data[‘image’]))
predicted = np.argmax(output.asnumpy(), axis=1)
#输出预测分类与实际分类
print(f’Predicted: “{predicted}”, Actual: “{labels}”')
从上面的打印结果可以看出,预测值与目标值完全一致。