Learning rate warmup in PyTorch


PyTorch provides learning rate schedulers for adjusting the learning rate in various ways during training. They are implemented in torch.optim.lr_scheduler.py and adjust the learning rate based on the epoch count. Most schedulers can be called back-to-back (also referred to as chaining schedulers): each scheduler is applied, one after another, to the learning rate produced by the previous one. A scheduler should be stepped after the optimizer's update, i.e. after optimizer.step().
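A minimal sketch of this chaining and of the required call order (the model here is just a placeholder):

import torch
import torch.optim as optim

model = torch.nn.Linear(10, 2)  # placeholder model
optimizer = optim.SGD(model.parameters(), lr=0.1)
# two chained schedulers: the second is applied to the lr produced by the first
scheduler1 = optim.lr_scheduler.ConstantLR(optimizer, factor=0.1, total_iters=2)
scheduler2 = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)

for epoch in range(4):
    optimizer.step()   # update the parameters first
    scheduler1.step()  # then step the schedulers, one after the other
    scheduler2.step()
    print(epoch, scheduler2.get_last_lr())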

Warmup was introduced in the ResNet paper: "We further explore n = 18 that leads to a 110-layer ResNet. In this case, we find that the initial learning rate of 0.1 is slightly too large to start converging. So we use 0.01 to warm up the training until the training error is below 80% (about 400 iterations), and then go back to 0.1 and continue training. The rest of the learning schedule is as done previously."
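The rule quoted above is driven by the training error rather than by an epoch schedule, so it is easiest to express by writing the learning rate into the optimizer's param_groups directly. A minimal sketch, where train_error is assumed to be whatever training-error measurement you already compute:

def resnet_style_warmup(optimizer, train_error, warmup_lr=0.01, target_lr=0.1):
    # per the ResNet paper: use the small warmup lr until the training
    # error drops below 80%, then switch back to the target lr
    lr = warmup_lr if train_error >= 0.8 else target_lr
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr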

Warmup is a learning rate optimization technique: train with a small learning rate during the initial phase to keep training stable, then gradually increase it to speed up convergence and improve model performance. It helps mitigate early overfitting to the first mini-batches, keeps the distribution the model sees effectively stationary, and helps stabilize the deeper layers of the model.

Because the model's weights are randomly initialized at the start of training, the loss is large, and choosing a large learning rate at this point can make the model unstable (oscillation). With warmup, the learning rate is kept small for the first few epochs or steps; under this small warmup learning rate the model gradually stabilizes, and once it is relatively stable, training continues with the preset learning rate. This helps keep the deeper layers stable, speeds up convergence, and yields a better model. In the final, stable stage, lowering the learning rate makes it easier to settle into a local optimum, and the batch size can be increased for additional stability.
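In code, this "small learning rate first, then the preset learning rate" policy can be assembled from the schedulers listed below, e.g. a LinearLR warmup handed off to a main schedule via SequentialLR. The epoch counts here are illustrative and the model is a placeholder:

import torch
import torch.optim as optim

model = torch.nn.Linear(10, 2)  # placeholder model
optimizer = optim.SGD(model.parameters(), lr=0.1)  # preset learning rate

# warm up for 5 epochs, ramping from 0.1*0.01 up to 0.1, then decay with cosine annealing
warmup = optim.lr_scheduler.LinearLR(optimizer, start_factor=0.01, end_factor=1.0, total_iters=5)
main = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=95)
scheduler = optim.lr_scheduler.SequentialLR(optimizer, schedulers=[warmup, main], milestones=[5])

for epoch in range(100):
    # ... train one epoch here (optimizer.step() per batch) ...
    scheduler.step()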

If convergence is too fast and the model quickly overfits the training set, lower the learning rate; if training is too slow or fails to converge, raise it.

A typical learning rate schedule: ramp up, hold steady, then decay.

Note: the content above is compiled from online sources.

Learning rate warmup and scheduling methods in PyTorch (assume the learning rate set in the optimizer is lr; a short sketch after the list prints the first few values of several of these schedules as a check):

(1).ConstantLR(optimizer, factor, total_iters): for the first total_iters epochs the learning rate is lr*factor; afterwards it becomes lr.

(2).LinearLR(optimizer, start_factor, end_factor, total_iters): over the first total_iters epochs the learning rate increases linearly from lr*start_factor; afterwards it stays at lr*end_factor.

(3).LambdaLR(optimizer, lr_lambda): lr_lambda is a lambda function; with the following definition, the learning rate is lr*0.95^epoch:

lr_lambda = lambda epoch: 0.95 ** epoch

(4).ExponentialLR(optimizer, gamma): the learning rate is lr*gamma^epoch.

(5).StepLR(optimizer, step_size, gamma): the learning rate is lr*gamma^(current epoch // step_size).

(6).MultiStepLR(optimizer, milestones, gamma): milestones is a list, e.g. [5,10,50,200]: the learning rate is lr for epoch < 5, lr*gamma^1 for epoch in [5,10), lr*gamma^2 for epoch in [10,50), and so on. Compared with StepLR, it allows the learning rate to decay at different points with different intervals.

(7).CosineAnnealingLR(optimizer, T_max, eta_min): the learning rate follows a cosine curve, decaying from lr down to eta_min over T_max epochs; if stepping continues past T_max, it rises back along the cosine (only the annealing part of the schedule is implemented, not the warm restarts). T_max is the number of iterations in a half cosine period and eta_min is the minimum learning rate.

(8).ReduceLROnPlateau(optimizer, mode, factor, patience, threshold, threshold_mode, cooldown, min_lr, eps): automatically reduces the learning rate when the validation loss stops improving. Instead of a predefined decay timetable, it adjusts the learning rate dynamically based on the model's performance; note that it must be stepped with the monitored value, i.e. scheduler.step(val_loss).

(9).CyclicLR(optimizer, base_lr, max_lr, ...): the learning rate cycles between the two given boundaries at a constant frequency.

(10).PolynomialLR(optimizer, total_iters, power): decays the learning rate with a polynomial function; for epochs beyond total_iters the learning rate stays at 0.
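To sanity-check the formulas above, the following sketch steps a dummy single-parameter optimizer (lr = 0.1) through a few of these schedulers and prints the resulting learning rates; the expected sequences are shown in the comments:

import torch
import torch.optim as optim

def show_schedule(make_scheduler, epochs=8):
    # dummy one-parameter optimizer, only used to observe the lr sequence
    optimizer = optim.SGD([torch.zeros(1, requires_grad=True)], lr=0.1)
    scheduler = make_scheduler(optimizer)
    lrs = []
    for _ in range(epochs):
        lrs.append(round(scheduler.get_last_lr()[0], 6))
        optimizer.step()
        scheduler.step()
    print(lrs)

show_schedule(lambda opt: optim.lr_scheduler.ConstantLR(opt, factor=0.5, total_iters=4))
# [0.05, 0.05, 0.05, 0.05, 0.1, 0.1, 0.1, 0.1]
show_schedule(lambda opt: optim.lr_scheduler.StepLR(opt, step_size=3, gamma=0.1))
# [0.1, 0.1, 0.1, 0.01, 0.01, 0.01, 0.001, 0.001]
show_schedule(lambda opt: optim.lr_scheduler.MultiStepLR(opt, milestones=[2, 5], gamma=0.1))
# [0.1, 0.1, 0.01, 0.01, 0.01, 0.001, 0.001, 0.001]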

The test code is shown below:

import colorama
import argparse
import time
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
from torchvision.datasets import ImageFolder
from torchvision import transforms
import torchvision.models as models

def parse_args():
    parser = argparse.ArgumentParser(description="learning rate warm up")
    parser.add_argument("--epochs", required=True, type=int, help="number of training")
    parser.add_argument("--dataset_path", required=True, type=str, help="source dataset path")
    parser.add_argument("--model_name", required=True, type=str, help="the model generated during training or the model loaded during prediction")
    parser.add_argument("--pretrained_model", type=str, default="", help="pretrained model loaded during training")
    parser.add_argument("--batch_size", type=int, default=2, help="specify the batch size")

    args = parser.parse_args()
    return args

def load_dataset(dataset_path, batch_size):
    mean = (0.53087615, 0.23997033, 0.45703197)
    std = (0.29807151489753686, 0.3128615049442739, 0.15151863355831655)

    transform = transforms.Compose([
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std), # RGB
    ])

    train_dataset = ImageFolder(root=dataset_path+"/train", transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)

    val_dataset = ImageFolder(root=dataset_path+"/val", transform=transform)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True, num_workers=0)

    assert len(train_dataset.class_to_idx) == len(val_dataset.class_to_idx), f"the number of categories in the train set must be equal to the number of categories in the validation set: {len(train_dataset.class_to_idx)} : {len(val_dataset.class_to_idx)}"

    return len(train_dataset.class_to_idx), len(train_dataset), len(val_dataset), train_loader, val_loader

def train(model, train_loader, device, optimizer, criterion, train_loss, train_acc):
    model.train() # set to training mode
    for _, (inputs, labels) in enumerate(train_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad() # clean existing gradients
        outputs = model(inputs) # forward pass
        loss = criterion(outputs, labels) # compute loss
        loss.backward() # backpropagate the gradients
        optimizer.step() # update the parameters

        train_loss += loss.item() * inputs.size(0) # compute the total loss
        _, predictions = torch.max(outputs.data, 1) # compute the accuracy
        correct_counts = predictions.eq(labels.data.view_as(predictions))
        acc = torch.mean(correct_counts.type(torch.FloatTensor)) # convert correct_counts to float
        train_acc += acc.item() * inputs.size(0) # compute the total accuracy

    return train_loss, train_acc

def validate(model, val_loader, device, criterion, val_loss, val_acc):
    model.eval() # set to evaluation mode
    with torch.no_grad():
        for _, (inputs, labels) in enumerate(val_loader):
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs) # forward pass
            loss = criterion(outputs, labels) # compute loss

            val_loss += loss.item() * inputs.size(0) # compute the total loss
            _, predictions = torch.max(outputs.data, 1) # compute validation accuracy
            correct_counts = predictions.eq(labels.data.view_as(predictions))
            acc = torch.mean(correct_counts.type(torch.FloatTensor)) # convert correct_counts to float
            val_acc += acc.item() * inputs.size(0) # compute the total accuracy

    return val_loss, val_acc

def training(epochs, dataset_path, model_name, pretrained_model, batch_size):
    classes_num, train_dataset_num, val_dataset_num, train_loader, val_loader = load_dataset(dataset_path, batch_size)

    model = models.ResNet(block=models.resnet.BasicBlock, layers=[2,2,2,2], num_classes=classes_num) # ResNet18
    if pretrained_model != "":
        model.load_state_dict(torch.load(pretrained_model))

    optimizer = optim.Adam(model.parameters(), lr=0.0001, betas=(0.9, 0.99), eps=1e-7) # set the optimizer

    scheduler = optim.lr_scheduler.ConstantLR(optimizer, factor=0.2, total_iters=10)
    # scheduler = optim.lr_scheduler.LinearLR(optimizer, start_factor=0.2, end_factor=0.8, total_iters=5)
    # assert len(optimizer.param_groups) == 1, f"optimizer.param_groups's length must be equal to 1: {len(optimizer.param_groups)}"
    # lr_lambda = lambda epoch: 0.95 ** epoch
    # scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
    # scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.95)
    # scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.2)
    # scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[5,10,15], gamma=0.2)
    # scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=5, eta_min=0.05)
    # scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min")
    # scheduler = optim.lr_scheduler.CyclicLR(optimizer, base_lr=0.01, max_lr=0.05)
    # scheduler = optim.lr_scheduler.PolynomialLR(optimizer, total_iters=5, power=1.)
    print(f"epoch: 0/{epochs}: learning rate: {scheduler.get_last_lr()}")

    criterion = nn.CrossEntropyLoss() # set the loss

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    highest_accuracy = 0.
    minimum_loss = 100.

    for epoch in range(epochs):
        epoch_start = time.time()

        train_loss = 0.0 # loss
        train_acc = 0.0 # accuracy
        val_loss = 0.0
        val_acc = 0.0

        train_loss, train_acc = train(model, train_loader, device, optimizer, criterion, train_loss, train_acc)
        val_loss, val_acc = validate(model, val_loader, device, criterion, val_loss, val_acc)
        # scheduler.step(val_loss) # update lr, ReduceLROnPlateau
        scheduler.step() # update lr

        avg_train_loss = train_loss / train_dataset_num # average training loss
        avg_train_acc = train_acc / train_dataset_num # average training accuracy
        avg_val_loss = val_loss / val_dataset_num # average validation loss
        avg_val_acc = val_acc / val_dataset_num # average validation accuracy

        epoch_end = time.time()
        print(f"epoch:{epoch+1}/{epochs}; learning rate: {scheduler.get_last_lr()}, train loss:{avg_train_loss:.6f}, accuracy:{avg_train_acc:.6f}; validation loss:{avg_val_loss:.6f}, accuracy:{avg_val_acc:.6f}; time:{epoch_end-epoch_start:.2f}s")

        if highest_accuracy < avg_val_acc and minimum_loss > avg_val_loss:
            torch.save(model.state_dict(), model_name)
            highest_accuracy = avg_val_acc
            minimum_loss = avg_val_loss

        if avg_val_loss < 0.00001 and avg_val_acc > 0.9999:
            print(colorama.Fore.YELLOW + "stop training early")
            torch.save(model.state_dict(), model_name)
            break

if __name__ == "__main__":
    # python test_learning_rate_warmup.py --epochs 1000 --dataset_path datasets/melon_new_classify --pretrained_model pretrained.pth --model_name best.pth
    colorama.init(autoreset=True)
    args = parse_args()

    training(args.epochs, args.dataset_path, args.model_name, args.pretrained_model, args.batch_size)

    print(colorama.Fore.GREEN + "====== execution completed ======")

The execution result is shown below:

      GitHub:https://github.com/fengbingchun/NN_Test

