Pytorch深度学习教程_10_神经网络训练

news/2025/3/29 4:28:12/

欢迎来到《深度学习保姆教程》系列的第九篇!在前面的几篇中,我们已经介绍了python基本用法,学习了梯度、激活函数、损失函数、优化算法等,在上一个教程中我们学习了搭建神经网络的nn模块,今天我们学习如何训练神经网络


目录

1 数据加载和预处理

(1)理解数据格式

数据加载

(2)数据清洗

(3)数据预处理

(4)数据划分

2 Loop

关键组件

3 评估指标

(1)分类指标

(2)回归指标

(3)选择合适的指标

4 模型保存和加载

(1)保存模型

保存整个模型

仅保存模型的状态字典

(2)加载模型

加载整个模型

 加载状态字典

 5 小结


1 数据加载和预处理

数据是机器学习模型的根本。你如何准备和处理数据会显著影响模型性能。

(1)理解数据格式

数据可以有多种格式:

  • CSV/Excel: 表格数据,包含行和列。
  • JSON: 键值对结构的数据。
  • 图像: 基于像素的表示。
  • 文本: 字符或单词的序列。
  • 音频/视频: 多通道的时间序列数据。
数据加载
  • : 使用像Pandas、NumPy和OpenCV这样的库进行高效的数据加载。
  • 文件格式: 处理不同的文件格式和编码。
  • 数据结构: 将数据转换为适当的数据结构(数组、张量)。
import pandas as pd
import numpy as np# 从CSV加载数据
data = pd.read_csv('data.csv')# 转换为numpy数组
data_array = data.to_numpy()

(2)数据清洗

  • 缺失值: 处理缺失数据(插补、删除)。
  • 异常值: 识别和处理异常值(移除、封顶、转换)。
  • 数据不一致: 修正错误和不一致性。
import pandas as pd# 处理缺失值
data = data.fillna(method='ffill')  # 用前一个值填充缺失值# 移除异常值
outlier_threshold = 100  # 假设阈值为100
data = data[data['column_name'] < outlier_threshold]

(3)数据预处理

  • 归一化: 将数值特征缩放到特定范围(0-1,-1到1)。
  • 标准化: 中心化和缩放特征,使其具有零均值和单位方差。
  • 特征编码: 将分类数据转换为数值格式(独热编码、标签编码)。
  • 特征提取: 从原始数据中提取相关特征。
from sklearn.preprocessing import StandardScalerscaler = StandardScaler()
data_scaled = scaler.fit_transform(data)

(4)数据划分

  • 训练集、验证集和测试集: 将数据划分为用于模型训练、评估和测试的子集。
from sklearn.model_selection import train_test_splitX_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

 

2 Loop

循环是机器学习的核心。它是将数据输入模型、计算误差并更新参数的迭代过程。

典型的训练循环包括以下步骤:

  • 数据加载: 从数据集中获取一批数据。
  • 前向传播: 将数据通过模型以获得预测。
  • 损失计算: 计算预测与真实值之间的差异。
  • 反向传播: 计算损失相对于模型参数的梯度。
  • 参数更新: 根据梯度使用优化器调整模型参数。
import torch# 假设你有一个模型、优化器和数据加载器
for epoch in range(num_epochs):for i, (inputs, labels) in enumerate(train_loader):# 将参数梯度清零optimizer.zero_grad()# 前向传播outputs = model(inputs)loss = criterion(outputs, labels)# 反向传播loss.backward()# 更新参数optimizer.step()
关键组件
  • 轮次 (Epoch)‌: 完整遍历整个数据集一次。
  • 批量大小 (Batch size)‌: 一次处理的样本数量。
  • 优化器 (Optimizer)‌: 用于更新模型参数的算法。
  • 损失函数 (Loss function)‌: 测量预测值与真实值之间误差的函数。

 

3 评估指标

评估指标是我们衡量模型性能的标尺。它们提供了关于模型优势和劣势的见解,帮助我们了解模型在未见数据上的泛化能力。

(1)分类指标

对于分类问题,常用的指标包括:

准确率 (Accuracy)‌: 正确预测的比例。

from sklearn.metrics import accuracy_scoreaccuracy = accuracy_score(y_true, y_pred)

 精确率 (Precision)‌: 被正确预测为正类的正类预测比例

from sklearn.metrics import precision_scoreprecision = precision_score(y_true, y_pred)

召回率 (Recall)‌: 实际正类中被正确识别的比例。

from sklearn.metrics import recall_scorerecall = recall_score(y_true, y_pred)

F1分数 (F1-score)‌: 精确率和召回率的调和平均值。

from sklearn.metrics import f1_scoref1 = f1_score(y_true, y_pred)

混淆矩阵 (Confusion matrix)‌: 一个总结分类算法性能的表格。

(2)回归指标

对于回归问题,常用的指标包括:

  • 均方误差 (Mean Squared Error, MSE)‌: 预测值与实际值之间平方差的平均值。
from sklearn.metrics import mean_squared_errormse = mean_squared_error(y_true, y_pred)
  • 平均绝对误差 (Mean Absolute Error, MAE)‌: 预测值与实际值之间绝对差的平均值。
from sklearn.metrics import mean_absolute_errormae = mean_absolute_error(y_true, y_pred)
  • 决定系数 (R-squared)‌: 模型解释的因变量方差的比例。
from sklearn.metrics import r2_scorer2 = r2_score(y_true, y_pred)

(3)选择合适的指标

指标的选择取决于问题和期望的结果:

  • 不平衡数据集: 精确率、召回率和F1分数可能比准确率更有信息量。
  • 异常值: MAE可能比MSE对异常值更稳健。
  • ROC曲线和AUC: 用于评估分类模型,特别是在不平衡数据集上。
  • 对数损失 (Log loss)‌: 衡量概率分类模型的性能。
  • 自定义指标: 为特定问题创建定制的指标。

 

4 模型保存和加载

保存和加载训练好的模型对于可重现性、部署和分享你的工作至关重要。PyTorch 提供了方便的工具来实现这一目的。

(1)保存模型

PyTorch 提供了两种主要的方法来保存模型:

保存整个模型

这种方法保留了模型的架构和参数。

import torchtorch.save(model, 'model.pth')

仅保存模型的状态字典

这会保存模型的参数,允许你将它们加载到不同的模型架构中(如果兼容的话)。

torch.save(model.state_dict(), 'model_params.pth')

(2)加载模型

要加载一个已保存的模型:

加载整个模型

loaded_model = torch.load('model.pth')

 加载状态字典

model = MyModel(*args, **kwargs)  # 创建模型实例
model.load_state_dict(torch.load('model_params.pth'))

 5 小结

本篇博客快速介绍神经网络训练的基本概念,包括数据加载处理、训练轮次、评估指标及模型保存加载。


http://www.ppmy.cn/news/1583160.html

相关文章

C++:类型推导规则 unsigned short + 1

在 C/C 中&#xff0c;整数提升&#xff08;Integer Promotion&#xff09; 规则决定了 vlan_id 1 的类型&#xff1a; unsigned short 的值在运算时会被 提升&#xff08;promote&#xff09; 到 int 或 unsigned int&#xff08;取决于平台&#xff09;。 默认情况下&#x…

深入理解Spring框架:核心概念与组成剖析

引言 在Java企业级开发领域&#xff0c;Spring框架无疑是当之无愧的王者。自2003年首次发布以来&#xff0c;Spring凭借其强大的功能、高度的灵活性和卓越的扩展性&#xff0c;已成为构建大型企业应用程序的首选框架。本文将深入探讨Spring框架的核心概念与多样组成部分&#…

Docker+Ollama+Xinference+RAGFlow+Dify部署及踩坑问题

目录 一、Xinference部署 &#xff08;一&#xff09;简介 &#xff08;二&#xff09;部署 &#xff08;三&#xff09;参数 &#xff08;四&#xff09;错误问题 二、下载Reranker模型 &#xff08;一&#xff09;huggingface下载 &#xff08;二&#xff09;modelco…

70款街头涂鸦手绘乱涂乱画线条png免抠图设计素材Scribble Asset Pack Includes 70 Assets

70款街头涂鸦手绘乱涂乱画线条png免抠图设计素材Scribble Asset Pack Includes 70 Assets 这只是一套漂亮的涂鸦和涂鸦包&#xff0c;供您在下一个设计中使用&#xff01;该包包含 70 个 PNG 文件/资产。

WPS的PPT智能图形增加项目

WPS新建了一页PPT&#xff0c;在这页PPT里增加智能图形&#xff0c;如何增加某个项目的数量。 比如原始是三个文本框&#xff0c;现在改成四个文本框&#xff0c;免去自己在原始图形上进行修改的麻烦。 方法如下&#xff1a; 通过以下选中要增加数量的项目&#xff0c;会弹出几…

《时间编码》

第一章&#xff1a;奇怪的文件 深夜&#xff0c;程序员苏晨坐在自己的工位上&#xff0c;盯着屏幕上那个奇怪的文件发呆。 这是他刚刚从公司服务器上下载的一个压缩包&#xff0c;文件名是“20230921.zip”。他本以为是某个同事上传的测试文件&#xff0c;可解压后却发现&…

【BFS】《BFS 攻克 FloodFill:填平图形世界的技术密码》

文章目录 前言例题一、 图像渲染二、 岛屿数量三、岛屿的最大面积四、被围绕的区域 结语 前言 什么是BFS&#xff1f; BFS&#xff08;Breadth - First Search&#xff09;算法&#xff0c;即广度优先搜索算法&#xff0c;是一种用于图或树结构的遍历算法。以下是其详细介绍&am…

特发性手抖是一种常见的神经系统问题

特发性手抖是一种常见的神经系统问题&#xff0c;主要症状为无意识的手部颤动。对于这种状况&#xff0c;护理是非常重要的。以下是一些特发性手抖的护理方法&#xff1a; 1. 保持积极心态&#xff1a;良好的心态对疾病的康复非常重要。应尽量保持心情舒畅&#xff0c;避免过度…