【深度学习】神经网络实战分类与回归任务

embedded/2025/2/1 7:25:43/

第一步 读取数据

①导入torch

import torch

②使用魔法命令,使它使得生成的图形直接嵌入到 Notebook 的单元格输出中,而不是弹出新的窗口来显示图形

%matplotlib inline

③读取文件

from pathlib import Path
import requestsDATA_PATH=Path("data")
PATH = DATA_PATH/"mnist"
PATH.mkdir(parents=True,exist_ok=True)URL = "http://deeplearning.net/data/mnist/"
FILENAME = "mnist.pkl.gz"if not (PATH/FILENAME).exists():content = requests.get(URL+FILENAME).content(PATH/FILENAME).open("wb").write(content)

④使用 gzippickle 模块加载一个压缩的 pickle 文件 (mnist.pkl.gz)

(PATH / FILENAME).as_posix():将 Path 对象转换为 POSIX 路径字符串,适用于跨平台环境。

import pickle
import gzipwith gzip.open((PATH/FILENAME).as_posix(),"rb") as f:((x_train,y_train),(x_valid,y_valid),_) = pickle.load(f,encoding="latin-1")

第二步 主体部分

①自定义神经网络模型

import torch.nn.functional as F
from torch import nnclass Mnist_NN(nn.Module):def __init__(self):super().__init__()self.hidden1 = nn.Linear(784,128)self.hidden2 = nn.Linear(128,256)self.out = nn.Linear(256,10)def forward(self,x):x = F.relu(self.hidden1(x))x = F.relu(self.hidden2(x))x = self.out(x)return x

②定义获取数据的方法

shuffle代表洗牌

def get_data(train_ds,valid_ds,bs):return (DataLoader(train_ds,batch_size=bs,shuffle=True),DataLoader(valid_ds,batch_size=bs*2))

③定义获取模型的方法

torch.optim 是 PyTorch 中用于定义各种优化算法的模块

from torch import optimdef get_model():model = Mnist_NN()return model,optim.Adam(model.parameters(),lr=0.001)

④定义损失函数

注1:调用model(xb)时会自动进行前向计算(forward pass)

这是因为PyTorch的nn.Module类(即所有神经网络模型的基类)内部实现了对__call__方法的重载。当通过实例化一个继承自nn.Module的类来创建对象时,并调用该对象(如果model(xb)),实际上是调用了这个对象的__call__方法。而__cacll__方法负责调用forward方法

注2:F.entropy是PyTorch中用于计算交叉熵损失的函数,位于torch.nn.functional模块中

它结合了log_softmax和nll_loss(负对数似然损失),使得在分类任务中可以直接使用,而无需显示地应用log_softmax

F提供了很多用于构建神经网络的方法,包括激活函数、损失函数、卷积操作、池化操作等

注3

    if opt is not None:
        loss.backward()
        opt.step()
        opt.zero_grad()

    代码段解析

这一段用于优化模型。它包含了反向传播(计算梯度)和参数更新的过程,是模型训练的核心步骤。

loss.backward()

        反向传播:调用backward()方法会根据损失函数对模型参数进行自动求导,计算每个参数的梯度,这些梯度将被存储在对应的参数张量的.grad属性中

反向传播是基于链式法则自动计算所有参数相对于损失的偏导数的过程

这一步骤对于更新模型参数至关重要,因为它提供了调整参数所需的方向信息

opt.step()

        参数更新:调用step()方法会使用之前计算的梯度来更新模型参数。具体的更新规则取决于所使用的优化算法(如SGD、Adam等),并且可能涉及到学习率、动量等超参数

opt.zero_grad()

        清除梯度:调用zero_grad()方法会将所有参数的梯度重置为零。这是必要的,因为PyTorch默认会累积梯度,而不是每次前向传播后自动清除它们。

如果不重置梯度,旧的梯度将会与新的梯度相加,导致不正确的梯度值,进而影响参数更新的效果。

通常在每次迭代结束时调用此方法以确保下一次前向传播时梯度是从零开始计算的。

loss_func = F.cross_entropydef loss_batch(model,loss_func,xb,yb,opt=None):loss = loss_func(model(xb),yb)if opt is not None:loss.backward()opt.step()opt.zero_grad()return loss.item(),len(xb)

⑤定义训练函数

import numpy as npdef fit(steps,model,loss_func,opt,train_dl,valid_dl):for step in range(steps):model.train()for xb,yb in train_dl:loss_batch(model,loss_func,xb,yb,opt)model.eval()with torch.no_grad():losses,nums = zip(*[loss_batch(model,loss_func,xb,yb) for xb,yb in valid_dl])val_loss = np.sum(np.multiply(losses,nums))/np.sum(nums)print("当前step:"+str(step)+",验证集损失:"+str(val_loss))

第三步 运行

①使用 Python 的内置 map() 函数,结合 PyTorch 的 torch.tensor 方法,将 x_train, y_train, x_valid, 和 y_valid 转换为 PyTorch 张量。这一步骤是数据预处理的一部分,确保所有数据都以张量的形式存储,从而可以直接用于 PyTorch 模型的训练和评估。

②加载数据集和数据

③加载模型

④训练,评估

x_train,y_train,x_valid,y_valid = map(torch.tensor,(x_train,y_train,x_valid,y_valid))train_ds = TensorDataset(x_train,y_train)
valid_ds = TensorDataset(x_valid,y_valid)
bs=64train_dl,valid_dl = get_data(train_ds,valid_ds,bs)
model,opt = get_model()
fit(25,model,loss_func,opt,train_dl,valid_dl)

运行结果:

测试训练精度:

correct = 0
total = 0
for xb,yb in valid_dl:outputs = model(xb)_,predicted = torch.max(outputs.data,1)total += yb.size(0)correct += (predicted==yb).sum().item()print("准确率为: %d %%" % (100*correct/total))


至此,该实战完成!


http://www.ppmy.cn/embedded/158570.html

相关文章

【C++】特殊类设计、单例模式与类型转换

目录 一、设计一个类不能被拷贝 (一)C98 (二)C11 二、设计一个类只能在堆上创建对象 (一)将构造函数私有化,对外提供接口 (二)将析构函数私有化 三、设计一个类只…

具身智能体空间感知基础!ROBOSPATIAL:评测并增强2D和3D视觉语言模型空间理解水平

作者:Chan Hee Song, Valts Blukis,Jonathan Tremblay, Stephen Tyree, Yu Su, Stan Birchfield 单位:俄亥俄州立大学,NVIDIA 论文标题:ROBOSPATIAL: Teaching Spatial Understanding to 2D and 3D Vision-Language Models for …

SpringSecurity:There is no PasswordEncoder mapped for the id “null“

文章目录 一、情景说明二、分析三、解决 一、情景说明 在整合SpringSecurity功能的时候 我先是去实现认证功能 也就是,去数据库比对用户名和密码 相关的类: UserDetailsServiceImpl implements UserDetailsService 用于SpringSecurity查询数据库 Logi…

Baklib如何变革企业知识管理提升工作效率与市场竞争力分析

内容概要 在当今数字化迅速发展的时代,企业面临着管理和运用知识资源的重大挑战。Baklib知识中台应运而生,成为企业提升知识管理的重要工具。通过构建一个集中化的平台,Baklib不仅使得知识的获取、分享和应用变得更加高效,同时也…

第23篇:Python开发进阶:详解测试驱动开发(TDD)

第23篇:测试驱动开发(TDD) 内容简介 在软件开发过程中,测试驱动开发(TDD,Test-Driven Development)是一种强调在编写实际代码之前先编写测试用例的开发方法。TDD不仅提高了代码的可靠性和可维…

pytorch实现门控循环单元 (GRU)

特性GRULSTM计算效率更快,参数更少相对较慢,参数更多结构复杂度只有两个门(更新门和重置门)三个门(输入门、遗忘门、输出门)处理长时依赖一般适用于中等长度依赖更适合处理超长时序依赖训练速度训练更快&am…

DeepSeek R1与OpenAI o1深度对比

文章目录 引言技术原理DeepSeek R1OpenAI o1 性能表现官方数据推理任务知识密集型任务通用能力 价格对比应用场景科研与技术开发自然语言处理(NLP)企业智能化升级教育与培训数据分析与智能决策 部署与集成DeepSeek R1OpenAI o1 伦理考量DeepSeek R1OpenA…

31. C语言 命令行参数

本章目录: 前言:什么是命令行参数?一个简单的示例运行结果 命令行参数的常见使用场景带空格的参数 高级命令行参数解析使用 getopt_long 的示例示例运行 注意事项进一步的实践:实现多功能程序总结 前言: 在 C 语言中,…