昇思25天学习打卡营第02天 | 快速入门

devtools/2024/10/22 8:32:14/

昇思25天学习打卡营第02天 | 快速入门

文章目录

  • 昇思25天学习打卡营第02天 | 快速入门
    • 数据准备
    • 网络构建
    • 模型训练
    • 模型测试
    • 迭代数据集
    • 模型保存
    • 加载模型
    • 总结
    • 打卡

数据准备

MindSpore通过DatasetTransforms实现高效的数据预处理

使用download下载数据,并创建数据集对象:

python">from download import downloadurl = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/MNIST_Data.zip"
path = download(url, "./", kind="zip", replace=True)train_dataset = MnistDataset('MNIST_Data/train')
test_dataset = MnistDataset('MNIST_Data/test')

数据集中的数据可以通过create_tuple_iteratorcreate_dict_iterator进行访问:

python">for image, label in test_dataset.create_tuple_iterator():print(f"Shape of image [N, C, H, W]: {image.shape} {image.dtype}")print(f"Shape of label: {label.shape} {label.dtype}")breakfor data in test_dataset.create_dict_iterator():print(f"Shape of image [N, C, H, W]: {data['image'].shape} {data['image'].dtype}")print(f"Shape of label: {data['label'].shape} {data['label'].dtype}")break

原始的数据通常不能满足需求,需要通过数据流水线(Data Processing Pipeline)指定map、batch、shuffle等操作进行处理,并将数据打包为指定大小的batch:

python">def datapipe(dataset, batch_size):image_transforms = [vision.Rescale(1.0 / 255.0, 0),vision.Normalize(mean=(0.1307,), std=(0.3081,)),vision.HWC2CHW()]label_transform = transforms.TypeCast(mindspore.int32)dataset = dataset.map(image_transforms, 'image')dataset = dataset.map(label_transform, 'label')dataset = dataset.batch(batch_size)return datasettrain_dataset = datapipe(train_dataset, 64)
test_dataset = datapipe(test_dataset, 64)

网络构建

通过继承nn.Cell类,并重写__init__construct方法来自定义网络结构。

python"># Define model
class Network(nn.Cell):def __init__(self):super().__init__()self.flatten = nn.Flatten()self.dense_relu_sequential = nn.SequentialCell(nn.Dense(28*28, 512),nn.ReLU(),nn.Dense(512, 512),nn.ReLU(),nn.Dense(512, 10))def construct(self, x):x = self.flatten(x)logits = self.dense_relu_sequential(x)return logitsmodel = Network()
print(model)

__init__方法中的dense_relu_sequential定义了网络的层次结构,由Dense和ReLU组成。
construct方法描述对输入数据的变换。

模型训练

一个模型的训练需要经过三个步骤:

  1. 正向计算:计算模型的预测值,并于正确标签求loss;
  2. 反向传播:利用自动微分机制,求模型参数对loss的梯度值;
    3.参数优化: 根据梯度值更新参数。

MindSpore使用函数式自动微分机制,因此需要实现:

  1. 定义正向计算函数:
python">loss_fn = nn.CrossEntropyLoss()
optimizer = nn.SGD(model.trainable_params(), 1e-2)def forward_fn(data, label):logits = model(data)loss = loss_fn(logits, label)return loss, logits
  1. 使用value_and_grad获得梯度计算函数:
python">grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)
  1. 定义训练函数,使用set_train设置为训练模式,执行正向计算、反向传播和参数优化:
python">def train_step(data, label):(loss, _), grads = grad_fn(data, label)optimizer(grads)return lossdef train(model, dataset):size = dataset.get_dataset_size()model.set_train()for batch, (data, label) in enumerate( dataset.create_tuple_iterator()):loss = train_step(data, label)if batch % 100 == 0:loss, current = loss.asnumpy(), batchprint(f"loss: {loss:>7f}  [{current:>3d}/{size:>3d}]")

模型测试

通过测试函数来评估模型的性能:

python">def test(model, dataset, loss_fn):num_batches = dataset.get_dataset_size()model.set_train(False)total, test_loss, correct = 0, 0, 0for data, label in dataset.create_tuple_iterator():pred = model(data)total += len(data)test_loss += loss_fn(pred, label).asnumpy()correct += (pred.argmax(1) == label).asnumpy().sum()test_loss /= num_batchescorrect /= totalprint(f"Test: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

迭代数据集

完整的遍历一次数据集成为一个epoch

python">epochs = 3
for t in range(epochs):print(f"Epoch {t+1}\n-------------------------------")train(model, train_dataset)test(model, test_dataset, loss_fn)
print("Done!")

模型保存

通过save_checkpoint保存模型的参数:

python">mindspore.save_checkpoint(model, "model.ckpt")
print("Saved Model to model.ckpt")

加载模型

模型的加载分为两步:

  1. 重新实例化网络模型:
python">model = Network()
  1. 加载模型参数,并加载至模型上:
python">param_dict = mindspore.load_checkpoint("model.ckpt")
param_not_load, _ = mindspore.load_param_into_net(model, param_dict)
print(param_not_load)

加载的模型可以直接用于预测推理:

python">model.set_train(False)
for data, label in test_dataset:pred = model(data)predicted = pred.argmax(1)print(f'Predicted: "{predicted[:10]}", Actual: "{label[:10]}"')break

总结

通过这一节的内容,对一个网络的诞生有了大概的认识,从原始数据到数据集的处理,简单网络结构的搭建,训练中自动微分机制的使用方法等都有了一定的了解。此外还有模型的保存与加载方法,为之后的深入学习奠定了基础。

打卡

在这里插入图片描述


http://www.ppmy.cn/devtools/57049.html

相关文章

【vueUse库Animation模块各函数简介及使用方法】

vueUse库是一个专门为Vue打造的工具库,提供了丰富的功能,包括监听页面元素的各种行为以及调用浏览器提供的各种能力等。其中的Browser模块包含了一些实用的函数,以下是这些函数的简介和使用方法: vueUse库Sensors模块各函数简介及使用方法 vueUseAnimation函数1. useInter…

emptyDir + initContainer实现ConfigMap的动态更新(K8s相关)

1. 絮絮叨叨 K8s部署服务时,一般都需要使用ConfigMap定义一些配置文件例如,部署分布式SQL引擎Presto,会在ConfigMap中定义coordinator、worker所需的配置文件以node.properties为例,node.environment和node.data-dir的值将由Helm…

基于STM32+华为云IOT设计的智能冰箱(华为云IOT)

文章目录 一、前言1.1 项目介绍【1】项目开发背景【2】设计实现的功能【3】项目硬件模块组成【4】摘要 1.2 设计思路1.3 系统功能总结1.4 开发工具的选择【1】设备端开发【2】上位机开发 二、部署华为云物联网平台2.1 物联网平台介绍2.2 开通物联网服务2.3 创建产品&#xff08…

《后端程序猿 · 基于 Lettuce 实现缓存容错策略》

📢 大家好,我是 【战神刘玉栋】,有10多年的研发经验,致力于前后端技术栈的知识沉淀和传播。 💗 🌻 近期刚转战 CSDN,会严格把控文章质量,绝不滥竽充数,如需交流&#xff…

Log4j日志框架讲解(全面,详细)

目录 Log4j概述 log4j的架构(组成) Loggers Appenders Layouts 快速入门 依赖 java代码 日志的级别 log4j.properties 自定义Logger 总结: Log4j概述 Log4j是Apache下的一款开源的日志框架,通过在项目中使用 Log4J&…

Nginx详解-安装配置等

目录 一、引言 1.1 代理问题 1.2 负载均衡问题 1.3 资源优化 1.4 Nginx处理 二、Nginx概述 三、Nginx的安装 3.1 安装Nginx 3.2 Nginx的配置文件 四、Nginx的反向代理【重点】 4.1 正向代理和反向代理介绍 4.2 基于Nginx实现反向代理 4.3 关于Nginx的location路径…

使用 Java Swing 和 XChart 创建多种图表

在现代应用程序开发中,数据可视化是一个关键部分。本文将介绍如何使用 Java Swing 和 XChart 库创建各种类型的图表。XChart 是一个轻量级的图表库,支持多种类型的图表,非常适合在 Java 应用中进行快速的图表绘制。 1、环境配置 在开始之前&…

用pycharm进行python爬虫的步骤

使用 pycharm 进行 python 爬虫的步骤:下载并安装 pycharm。创建一个新项目。安装 requests 和 beautifulsoup 库。编写爬虫脚本,包括获取页面内容、解析 html 和提取数据的代码。运行爬虫脚本。保存和处理提取到的数据。 用 PyCharm 进行 Python 爬虫的…