【Pytorch】优化器(Optimizer)模块‘torch.optim’

server/2024/12/2 13:35:02/

torch.optim 是 PyTorch 中提供的优化器(Optimizer)模块,用于优化神经网络模型的参数,更新网络权重,使得模型在训练过程中最小化损失函数。它提供了多种常见的优化算法,如 梯度下降法(SGD)AdamAdagradRMSprop 等,用户可以根据需要选择合适的优化方法。

目录

      • 优化器的工作原理
      • `torch.optim` 中的常见优化器
      • 常用优化器参数
      • 优化器的基本使用方法
      • 完整示例
      • 总结

优化器的工作原理

优化器通过计算损失函数对模型参数的梯度(通常使用反向传播算法),然后根据优化算法的规则更新模型的参数,以逐步减少损失函数的值。具体更新规则取决于所选的优化算法。

torch.optim 中的常见优化器

  1. SGD(Stochastic Gradient Descent)

    • SGD 是最基本的优化算法,它通过计算损失函数的梯度,并按某个学习率(learning rate)更新模型的参数。
    • 可以选择是否使用动量(momentum)来加速收敛。

    示例

    python">optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    
  2. Adam(Adaptive Moment Estimation)

    • Adam 是一种结合了动量法(Momentum)和自适应学习率(AdaGrad)的优化算法。它会分别对每个参数维护一个一阶矩估计(梯度的平均值)和二阶矩估计(梯度的平方的平均值),从而自适应地调整每个参数的学习率。
    • Adam 通常比 SGD 更常用于深度学习中的优化,尤其是在处理大规模数据时。

    示例

    python">optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    
  3. Adagrad(Adaptive Gradient Algorithm)

    • Adagrad 是一种自适应优化算法,它为每个参数分配不同的学习率,并根据每个参数的梯度历史调整学习率。梯度大的参数会减小学习率,而梯度小的参数会增大学习率。

    示例

    python">optimizer = torch.optim.Adagrad(model.parameters(), lr=0.01)
    
  4. RMSprop(Root Mean Square Propagation)

    • RMSprop 是 Adagrad 的一种变体,旨在解决 Adagrad 学习率过早衰减的问题。它使用指数衰减的平均来计算梯度的平方,从而避免了梯度下降时过早减小学习率。

    示例

    python">optimizer = torch.optim.RMSprop(model.parameters(), lr=0.01, alpha=0.99)
    
  5. AdamW(Adam with Weight Decay)

    • AdamW 是 Adam 优化器的一个变种,加入了权重衰减(weight decay),用来防止模型过拟合。它与标准的 Adam 不同之处在于,它在参数更新过程中将权重衰减项分离出来,避免了标准 Adam 中衰减项的负面影响。

    示例

    python">optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01)
    
  6. LBFGS(Limited-memory Broyden–Fletcher–Goldfarb–Shanno)

    • LBFGS 是一种二阶优化方法,它使用目标函数的二阶导数(Hessian 矩阵的近似)来加速收敛。与其他一阶方法相比,它在计算和内存使用上比较昂贵,但在某些特定问题中(如小批量数据和二次优化问题)能够提供更快的收敛速度。

    示例

    python">optimizer = torch.optim.LBFGS(model.parameters(), lr=0.1)
    

常用优化器参数

每个优化器通常会接受以下几个参数:

  • params:待优化的参数(通常是模型的权重),可以使用 model.parameters() 获取。
  • lr(Learning Rate):学习率,控制每次参数更新的步长。较小的学习率可能导致收敛过慢,较大的学习率可能导致发散。
  • momentum(可选):用于动量的参数,通常用来加速收敛。
  • weight_decay(可选):L2 正则化系数,用于防止模型过拟合。
  • betas(Adam 和一些其他优化器):用于控制一阶矩(梯度的均值)和二阶矩(梯度的方差)衰减率的超参数。

优化器的基本使用方法

  1. 创建优化器
    通常在定义了模型后,通过 torch.optim 创建一个优化器,并将模型的参数传递给优化器。

    python">optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    
  2. 梯度清零
    在每次迭代前,需要将模型参数的梯度清零,避免梯度累积。

    python">optimizer.zero_grad()
    
  3. 计算梯度
    使用反向传播计算梯度。

    python">loss.backward()
    
  4. 更新参数
    调用 step() 方法,根据计算出的梯度更新模型的参数。

    python">optimizer.step()
    

完整示例

下面是一个完整的使用优化器的示例:

python">import torch
import torch.nn as nn
import torch.optim as optim# 定义一个简单的神经网络
class SimpleNet(nn.Module):def __init__(self):super(SimpleNet, self).__init__()self.fc1 = nn.Linear(10, 20)self.fc2 = nn.Linear(20, 1)def forward(self, x):x = torch.relu(self.fc1(x))x = self.fc2(x)return x# 创建模型
model = SimpleNet()# 创建优化器(使用 Adam 优化器)
optimizer = optim.Adam(model.parameters(), lr=0.001)# 假设有一些输入数据和目标标签
input_data = torch.randn(5, 10)  # 输入数据:5个样本,每个样本10维
target = torch.randn(5, 1)       # 目标标签:5个样本,每个样本1维# 定义损失函数
criterion = nn.MSELoss()# 训练过程
for epoch in range(100):  # 训练 100 次# 前向传播output = model(input_data)# 计算损失loss = criterion(output, target)# 清零梯度optimizer.zero_grad()# 反向传播loss.backward()# 更新参数optimizer.step()# 打印每个 epoch 的损失if (epoch + 1) % 10 == 0:print(f'Epoch [{epoch+1}/100], Loss: {loss.item():.4f}')

总结

  • torch.optim 提供了多种优化器(如 SGD、Adam、RMSprop 等)用于训练神经网络,用户可以选择合适的优化器来优化模型的参数。
  • 常见的优化器包括 Adam(适应性调整学习率)、SGD(随机梯度下降)、RMSpropAdagrad 等,选择哪个优化器取决于你的任务、模型和实验。
  • 优化器的核心工作流程包括:清零梯度、计算梯度、反向传播、更新参数。

选择合适的优化器和调优超参数(如学习率)是深度学习训练的一个关键部分。


http://www.ppmy.cn/server/146731.html

相关文章

Apache storm安装教程(单机版)

本章教程基于linux centos7安装Apache storm 单机版。 jdk版本:11.0.25 python版本:3.6.8 ZooKeeper 版本:3.7.2 Apache storm版本:2.7.1 一、Apache storm简介 Apache Storm 是一个分布式实时计算系统,专为处理大规模流数据而设计。它允许开发者实时处理大数据流,并对数…

SQL语句在MySQL中的执行过程

一 MySQL 基础架构分析 1.1 MySQL 基本架构概览 下图是 MySQL 的一个简要架构图,从下图你可以很清晰的看到用户的 SQL 语句在 MySQL 内部是如何执行的。 先简单介绍一下下图涉及的一些组件的基本作用帮助大家理解这幅图,在 1.2 节中会详细介绍到这些组…

HarmonyOS4+NEXT星河版入门与项目实战(25)------UIAbility启动模式(文档编辑案例)

文章目录 1、启动模式2、Specified启动模式实现步骤3、文档编辑案例1、文件创建2代码实现3、Statge 创建4、添加配置1、启动模式 Singleton启动模式: 每个 UIAbility 只存在一个实例,是默认的启动模式,任务列表中只会存在一个相同的 UIAbilityStandard启动模式: 每次启动 U…

RabbitMQ在手动消费的模式下设置失败重新投递策略

最近在写RabbitMQ的消费者,因为业务需求,希望失败后重试一定次数,超过之后就不处理了,或者放入死信队列。我这里就达到重试次数后就不处理了。本来以为很简单的,问了kimi,按它的方法配置之后,发…

C++20: 像Python一样逐行读取文本文件并支持切片操作

概要 逐行读取文本文件,并提取其中连续的几行,这对于 Python 来说是小菜一碟。 C 则很笨拙, 语言不自带这些。 这次我来拯救 C boys & girls, 在 C20 环境下,山寨一个 Python 下的逐行读文本文件、支持 slice 操作…

【C++】 list接口以及模拟实现

list介绍 list文档介绍 C中的list是一个双向链表容器。它允许在任意位置进行快速插入和删除操作,并且能够在常量时间内访问任意元素,并且该容器可以前后双向迭代。 1. list是可以在常数范围内在任意位置进行插入和删除的序列式容器,并且该容…

arcgis for js FeatureLayer和GeoJSON一个矢量点同时渲染图形和文本

效果 FeatureLayer和GeoJSONLayer, 一个矢量点同时渲染图形和文本 代码 相关参数自行查阅文档, 这里就不做注释了 示例代码手动创建FeatureLayer方式, 如果是通过远程url加载图层的 渲染方式同理, GeoJSONLayer同理 <!DOCTYPE html> <html lang"zn"><…

[极客时间]AIGC产品经理训练营毕业总结

为期10周的训练营也进入尾声了,回想当初&#xff0c;真的是蛮感慨的。 我为什么会去参加AIGC产品训练营呢? 其实也是蛮奇妙的. 我是一名传统业务的前端研发, 因为一个项目第一次真实的接触到了AIGC,就感觉像是打开了新世界的大门,让我倍感兴奋,想要在ai的世界里探索一番. 不过…