优化器调整策略

news/2024/10/17 19:25:08/

损失函数的作用是衡量模型输出与真实标签的差异。当我们有了这个loss之后,我们就可以通过反向传播机制得到参数的梯度,那么我们如何利用这个梯度进行更新参数使得模型的loss逐渐的降低呢?

优化器的作用

Pytorch的优化器: 管理更新模型中可学习参数的值, 使得模型输出更接近真实标签。

Optimizer的基本属性

在这里插入图片描述

optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
  • defaults: 优化器超参数,里面会存储一些学习率, momentum的值,衰减系数等
  • state: 参数的缓存, 如momentum的缓存(使用前几次梯度进行平均)
  • param_groups: 管理的参数组, 这是个列表,每一个元素是一个字典,在字典中有key,key里面的值才是我们真正的参数(这个很重要, 进行参数管理)
  • _step_count: 记录更新次数, 学习率调整中使用, 比如迭代100次之后更新学习率的时候,就得记录这里的100.

Optimizer的基本方法

在这里插入图片描述

  • zero_grad()梯度清零。清空所管理参数的梯度, 这里注意Pytorch有一个特性就是张量梯度不自动清零
  • step(): 执行一步更新
  • add_param_group(): 添加参数组, 我们知道优化器管理很多参数,这些参数是可以分组的,我们对不同组的参数可以设置不同的超参数, 比如模型finetune中,我们希望前面特征提取的那些层学习率小一些,而后面我们新加的层学习率大一些更新快一点,就可以用这个方法。
  • state_dict(): 获取优化器当前状态信息字典
  • load_state_dict(): 加载状态信息字典,这两个方法用于模型断点的一个续训练, 所以我们在模型训练的时候,一般多少个epoch之后就要保存当前的状态信息。
  • 在这里插入图片描述
    这里就是optimizer的__init__初始化部分了,可以看到上面介绍的那几个属性和它们的初始化方法,当然这里有个最重要的就是参数组的添加,我们看看是怎么添加的
    在这里插入图片描述
    这里重点说一下这个,我们还记得初始化SGD的时候传入了一个形参:optim.SGD(model.parameters(), lr=LR, momentum=0.9),这里的model.parameters() 就是神经网络的每层的参数, SGD在初始化的时候, 会把这些参数以参数组的方式再存起来, 上图中的params就是神经网络每一层的参数。

def __init__(self, params, defaults):这里的params其实就是实参model.parameters() 传入进来的
这就是优化器的初始化工作了, 初始化完了之后, 我们就可以进行梯度清空,然后更新梯度即可:
在这里插入图片描述

动量

Momentum:结合当前梯度与上一次更新信息, 用于当前更新。这么说可能有点抽象, 那么我们可以举个比较形象的例子:
在这里插入图片描述

指数加权平均在时间序列中经常用于求取平均值的一个方法,它的思想是这样,我们要求取当前时刻的平均值,距离当前时刻越近的那些参数值,它的参考性越大,所占的权重就越大,这个权重是随时间间隔的增大呈指数下降,所以叫做指数滑动平均。公式如下:

在这里插入图片描述
vt 是当前时刻的一个平均值,这个平均值有两项构成

  • 一项是当前时刻的参数值θt, 所占的权重是1 − β , 这个β是个参数。
  • 另一项是上一时刻的一个平均值, 权重是β。

假设我想求第100天温度的一个平均值,那么根据上面的公式:
在这里插入图片描述
我们发现,距离当前时刻越远的那些 θ 值,它的权重是越来越小的,因为 β 小于1, 所以间隔越远,小于1的这些数连乘,权重越来越小,而且是呈指数下降,因为这里是βi 。

Momentum梯度下降:
当前梯度的更新量会考虑到之前梯度, 上一时刻的梯度,前一时刻的梯度,这样一直往前,只不过越往前权重越小而已。

model.state_dict 和 optimizer.state_dict


http://www.ppmy.cn/news/1063101.html

相关文章

Java 时间日期处理,工作必用(建议收藏)

工作中经常会遇到对时间日期进行处理的业务,像日期类的API个人觉得不需要背,需要的时候去查资料就行。我整理了Java8之前及之后日期类常用的时间日期处理方法,方便工作需要时查找,觉得有用的朋友可以收藏。 一、日期格式化和解析 …

Macbook pro M1 安装Ubuntu教程

先讲下心路历程 由于版主最近刚切换到Mac,所以在安装的时候一上手就选择了virutalbox,结果报错“The installer has detected an unsupported architecture. VirtualBox only runs on the amd64 architecture.” 后来去Reddit论坛上一看,才知…

《认知觉醒》读书笔记之潜意识

模糊--人生是一场消除模糊的比赛。 学习知识,消除认知模糊 掌握的工具越多,认知能力越强,消除模糊的能力就越强。 元认知-----》 如何反观自己。 刻意练习----》 如何精进自己。 运动改造大脑---》 如何激化自己的运动热情。 学习知识的…

【安装GPU版本pytorch,torch.cuda.is_available()仍然返回False问题】

TOC 第一步 检查cuda是否安装,CUDA环境变量是否正确设置,比如linux需要设置在PATH,window下环境变量编辑看看,是否有CUDA 第二步,核查python中torch版本 首先查看你环境里的pytorch是否是cuda版本,我这…

【力扣】盛最多水的容器

目录 题目 题目初步解析 水桶效应 代码实现逻辑 第一步 第二步 第三步 代码具体实现 注意 添加容器元素的函数 计算迭代并且判断面积是否是最大值 总代码 运行结果 总结 题目 给定一个长度为 n 的整数数组 height 。有 n 条垂线,第 i 条线的两个端点是…

4.16 TCP 协议有什么缺陷?

目录 升级 TCP 的工作很困难 TCP 建立连接的延迟 TCP 存在队头阻塞问题 网络迁移需要重新建立 TCP 连接 升级 TCP 的工作很困难;TCP 建立连接的延迟;TCP 存在队头阻塞问题;网络迁移需要重新建立 TCP 连接; 升级 TCP 的工作很…

adb使用总结

adb连接到模拟器 adb devices 打开模拟器,找到设置。 多次点击版本号,切换到开发者模式 搜索进入开发者选项 开启USB调试 此时在终端输入adb devices就连接上了 使用adb查看安卓手机架构 adb shell getprop ro.product.cpu.abi 进入安卓手机的shell …

五、多表查询-3.4连接查询-联合查询union

一、概述 二、演示 【例】将薪资低于5000的员工,和 年龄大于50岁的 员工全部查询出来 1、查询薪资低于5000的员工 2、查询年龄大于50岁的员工 3、将薪资低于5000的员工,和 年龄大于50岁的 员工全部查询出来(把上面两部分的结果集直接合并起…