pytorch对不同的可调参数,分配不同的学习率

devtools/2024/10/21 6:06:32/

在 PyTorch 中,你可以通过为优化器传递不同的学习率来针对不同的可调参数分配不同的学习率。这通常通过向优化器传递一个字典列表来实现,其中每个字典指定特定参数组的学习率。下面是一个示例代码,展示了如何实现这一点:

import torch
import torch.optim as optim# 假设我们有两个模型参数:param1 和 param2
param1 = torch.nn.Parameter(torch.randn(2, 3))
param2 = torch.nn.Parameter(torch.randn(3, 4))# 将这些参数分配给不同的学习
optimizer = optim.SGD([{'params': param1, 'lr': 0.01},{'params': param2, 'lr': 0.001}
], lr=0.01, momentum=0.9)# 模拟一次训练步骤
loss = (param1.sum() + param2.sum()) ** 2
loss.backward()
optimizer.step()# 打印更新后的参数值
print(param1)
print(param2)

对于余弦退火算法中,对于可调的学习率,pytorch对不同的可调参数,分配不同的学习率权重

import torch
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR# 假设我们有两个模型参数:param1 和 param2
param1 = torch.nn.Parameter(torch.randn(2, 3))
param2 = torch.nn.Parameter(torch.randn(3, 4))# 为每个参数组分配不同的学习
optimizer = optim.SGD([{'params': param1, 'lr': 0.01},{'params': param2, 'lr': 0.001}
], lr=0.01, momentum=0.9)# 为整个优化器设置余弦退火调度器
scheduler = CosineAnnealingLR(optimizer, T_max=10, eta_min=0.0001)# 模拟一个训练周期
for epoch in range(10):# 执行优化步骤loss = (param1.sum() + param2.sum()) ** 2loss.backward()optimizer.step()# 更新学习scheduler.step()# 打印当前学习for i, param_group in enumerate(optimizer.param_groups):print(f'Epoch {epoch+1}, Param Group {i+1}: Learning Rate = {param_group["lr"]}')

两个参数先后优化,第一阶段主要优化param1,后一阶段主要优化param2

方法1:分阶段调整优化器的参数组
你可以在第一阶段只优化 param1,然后在第二阶段只优化 param2。这可以通过在不同阶段将 param1 或 param2 从优化器中移除或冻结(将学习率设置为 0)来实现。

import torch
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR# 假设我们有两个模型参数:param1 和 param2
param1 = torch.nn.Parameter(torch.randn(2, 3))
param2 = torch.nn.Parameter(torch.randn(3, 4))# 第一阶段:仅优化 param1
optimizer1 = optim.SGD([{'params': param1, 'lr': 0.01}], momentum=0.9)
scheduler1 = CosineAnnealingLR(optimizer1, T_max=5, eta_min=0.0001)# 第二阶段:仅优化 param2
optimizer2 = optim.SGD([{'params': param2, 'lr': 0.001}], momentum=0.9)
scheduler2 = CosineAnnealingLR(optimizer2, T_max=5, eta_min=0.0001)# 模拟训练
for epoch in range(10):# 第一阶段:前5个epoch优化param1if epoch < 5:optimizer1.zero_grad()loss = (param1.sum()) ** 2loss.backward()optimizer1.step()scheduler1.step()print(f'Epoch {epoch+1}: Optimizing param1, LR = {scheduler1.get_last_lr()}')# 第二阶段:后5个epoch优化param2else:optimizer2.zero_grad()loss = (param2.sum()) ** 2loss.backward()optimizer2.step()scheduler2.step()print(f'Epoch {epoch+1}: Optimizing param2, LR = {scheduler2.get_last_lr()}')

方法2:同时设置不同的学习率,但不同阶段侧重不同的参数
在这个方法中,你可以在第一阶段为 param1 设置较大的学习率,param2 设置为非常小的学习率(几乎不变)。然后在第二阶段反过来。

import torch
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR# 假设我们有两个模型参数:param1 和 param2
param1 = torch.nn.Parameter(torch.randn(2, 3))
param2 = torch.nn.Parameter(torch.randn(3, 4))# 同时优化param1和param2,但不同阶段有不同的学习
optimizer = optim.SGD([{'params': param1, 'lr': 0.01},  # param1初始学习率较大{'params': param2, 'lr': 0.0001}  # param2初始学习率较小
], momentum=0.9)scheduler = CosineAnnealingLR(optimizer, T_max=10, eta_min=0.00001)# 模拟训练
for epoch in range(10):optimizer.zero_grad()# 计算损失loss = (param1.sum() + param2.sum()) ** 2loss.backward()optimizer.step()scheduler.step()# 不同阶段调整学习if epoch == 5:optimizer.param_groups[0]['lr'] = 0.0001  # param1 学习率降低optimizer.param_groups[1]['lr'] = 0.01    # param2 学习率增大# 打印学习print(f'Epoch {epoch+1}: LR for param1 = {optimizer.param_groups[0]["lr"]}, LR for param2 = {optimizer.param_groups[1]["lr"]}')

http://www.ppmy.cn/devtools/111132.html

相关文章

Python 多线程

开始学习Python线程 线程模块 使用Threading模块创建线程 线程同步 线程优先级队列&#xff08; Queue&#xff09; 多线程类似于同时执行多个不同程序&#xff0c;多线程运行有如下优点&#xff1a; 使用线程可以把占据长时间的程序中的任务放到后台去处理。用户界面可以更…

Pyspark下操作dataframe方法(1)

文章目录 Pyspark dataframe创建DataFrame使用Row对象使用元组与scheam使用字典与scheam注意 agg 聚合操作alias 设置别名字段设置别名设置dataframe别名 cache 缓存checkpoint RDD持久化到外部存储coalesce 设置dataframe分区数量collect 拉取数据columns 获取dataframe列 Pys…

苹果宣布iOS 18正式版9月17日推送:支持27款iPhone升级

9月10日消息&#xff0c;在苹果秋季发布会结束后&#xff0c; 苹果宣布将于9月17日(下周二)推送iOS 18正式版系统。 苹果官网显示&#xff0c;iOS 18正式版将兼容第二代iPhone SE及之后的所有机型&#xff0c;加上刚发布的iPhone 16系列&#xff0c;共兼容27款iPhone。 iOS 18升…

【数据获取与读取】JSON CSV

数据分析流程 获取数据-读取数据-评估数据-清洗数据-整理数据-分析数据-可视化数据 公开数据集 飞桨&#xff08;百度旗下深度学习平台&#xff09;数据集&#xff1a;https:/aistudio.baidu.com/aistudio/datasetoverview 天池&#xff08;阿里云旗下开发者竞赛平台&#xf…

​了解MySQL 的二进制日志文件​Binlog

1. SQL 语句的几种类型 首先介绍一下&#xff0c;对于一个 SQL 语句&#xff0c;它常常被分为以下几种类型&#xff1a; DDL&#xff08;Data Definition Language&#xff0c;数据定义语言&#xff09;&#xff1a;用来操作数据库、表、列等&#xff0c;比如 CREATE、ALTER…

现在有一台ubuntu22.04 的工作站机器,现在想通过RDP的方式进行远程开发

在 Ubuntu 22.04 工作站上通过 RDP&#xff08;远程桌面协议&#xff09;进行连接的具体步骤如下&#xff1a; 1. 安装 RDP 服务 Ubuntu 默认不支持 RDP 连接&#xff0c;因此你需要安装一个 RDP 服务器&#xff0c;通常使用 xrdp 这个软件包。 步骤&#xff1a; 打开终端&a…

核心知识点合集

不断补充... 知识模块文章详情心得&资料JVM相关 Java&#xff1a;简述JDK&#xff0c;JRE&#xff0c;JVM之间的关系JVM&#xff1a;简述JVM内存分配模型JVM&#xff1a;浅谈垃圾回收GC机制JVM&#xff1a;浅谈内存溢出的原因JVM&#xff1a;浅谈JVM调优策略 线程相关资料…

Spring-bean的生命周期-尾篇

上回说到阶段9&#xff0c;现在我们接着往下说 阶段10&#xff1a;所有单例bean初始化完成后阶段 所有单例bean实例化完成之后&#xff0c;spring会回调下面这个接口&#xff1a; package org.springframework.beans.factory;public interface SmartInitializingSingleton {…