【PyTorch】5-进阶训练技巧(损失函数、学习率、模型微调、半精度训练、数据增强、超参数设置)

ops/2024/10/11 13:30:43/

PyTorch:5-进阶训练技巧

注:所有资料来源且归属于thorough-pytorch(https://datawhalechina.github.io/thorough-pytorch/),下文仅为学习记录

5.1:自定义损失函数

PyTorch在torch.nn模块提供了许多常用的损失函数,比如:MSELoss,L1Loss,BCELoss

非官方提供的Loss,比如:DiceLoss,HuberLoss,SobolevLoss

5.1.1:以函数方式定义

损失函数仅是一个函数

def my_loss(output, target):loss = torch.mean((output - target)**2)return loss

5.1.2:以类方式定义

如果看每一个损失函数的继承关系,可以发现Loss函数部分继承自_loss,部分继承自_WeightedLoss,而_WeightedLoss继承自_loss _loss继承自 nn.Module

损失函数类需要继承自nn.Module类。

【案例:DiceLoss】

应用:分割
D i c e L o s s = 2 ∣ X ∩ Y ∣ ∣ X ∣ + ∣ Y ∣ DiceLoss=\frac{2|X∩Y|}{|X|+|Y|} DiceLoss=X+Y2∣XY
实现代码:

class DiceLoss(nn.Module):def __init__(self,weight=None,size_average=True):super(DiceLoss,self).__init__()def forward(self,inputs,targets,smooth=1):inputs = F.sigmoid(inputs)       inputs = inputs.view(-1)targets = targets.view(-1)intersection = (inputs * targets).sum()                   dice = (2.*intersection + smooth)/(inputs.sum() + targets.sum() + smooth)  return 1 - dice# 使用方法    
criterion = DiceLoss()
loss = criterion(input,targets)

在自定义损失函数时,可能涉及到数学运算,最好全程使用PyTorch提供的张量计算接口,这样就不需要实现自动求导功能,并且可以直接调用cuda

【案例:IoULoss】

实现代码:

class IoULoss(nn.Module):def __init__(self, weight=None, size_average=True):super(IoULoss, self).__init__()def forward(self, inputs, targets, smooth=1):inputs = F.sigmoid(inputs)       inputs = inputs.view(-1)targets = targets.view(-1)intersection = (inputs * targets).sum()total = (inputs + targets).sum()union = total - intersection IoU = (intersection + smooth)/(union + smooth)return 1 - IoU

5.2:动态调整学习

学习率过小:极大降低收敛速度,增加训练时间

学习率过大:导致参数在最优解两侧来回振荡

scheduler——通过一个适当的学习率衰减策略来改善这种现象

5.2.1:官方scheduler

PyTorch已经在torch.optim.lr_scheduler封装好了一些动态调整学习率的方法

  • lr_scheduler.LambdaLR
  • lr_scheduler.MultiplicativeLR
  • lr_scheduler.StepLR
  • lr_scheduler.MultiStepLR
  • lr_scheduler.ExponentialLR
  • lr_scheduler.CosineAnnealingLR
  • lr_scheduler.ReduceLROnPlateau
  • lr_scheduler.CyclicLR
  • lr_scheduler.OneCycleLR
  • lr_scheduler.CosineAnnealingWarmRestarts
  • lr_scheduler.ConstantLR
  • lr_scheduler.LinearLR
  • lr_scheduler.PolynomialLR
  • lr_scheduler.ChainedScheduler
  • lr_scheduler.SequentialLR

这些scheduler都是继承自_LRScheduler类。

使用API:

# 选择一种优化器
optimizer = torch.optim.Adam(...) 
# 选择上面提到的一种或多种动态调整学习率的方法
scheduler1 = torch.optim.lr_scheduler.... 
scheduler2 = torch.optim.lr_scheduler....
...
schedulern = torch.optim.lr_scheduler....
# 进行训练
for epoch in range(100):train(...)validate(...)optimizer.step()# 需要在优化器参数更新之后再动态调整学习
# scheduler的优化是在每一轮后面进行的
scheduler1.step() 
...
schedulern.step()

在使用官方给出的torch.optim.lr_scheduler时,需要将scheduler.step()放在optimizer.step()后面进行使用。

5.2.2:自定义scheduler

方法:自定义函数adjust_learning_rate来改变param_grouplr的值

假设需要学习率每30轮下降为原来的1/10,实现代码如下:

def adjust_learning_rate(optimizer, epoch):lr = args.lr * (0.1 ** (epoch // 30))for param_group in optimizer.param_groups:param_group['lr'] = lr

在训练的过程调用学习率修正函数来实现学习率的动态变化

optimizer = torch.optim.SGD(model.parameters(),lr = args.lr,momentum = 0.9)
for epoch in range(10):train(...)validate(...)adjust_learning_rate(optimizer,epoch)

5.3:模型微调——torchvision

迁移学习(transfer learning):将从源数据集学到的知识迁移到目标数据集上。

模型微调(finetune):先找到一个同类的别人训练好的模型,把别人现成的训练好了的模型拿过来,换成自己的数据,通过训练调整一下参数。

5.3.1:模型微调流程

  1. 在源数据集上预训练一个神经网络模型,即源模型。
  2. 创建一个新的神经网络模型,即目标模型。它复制了源模型上除了输出层外的所有模型设计及其参数。
  3. 为目标模型添加一个输出大小为目标数据集类别个数的输出层,并随机初始化该层的模型参数。
  4. 在目标数据集上训练目标模型。从头训练输出层,而其余层的参数基于源模型的参数微调。

5.3.2:使用已有模型结构

实例化模型

import torchvision.models as models
resnet18 = models.resnet18()

传递参数

通过True或者False来决定是否使用预训练好的权重,在默认状态下pretrained = False

import torchvision.models as models
resnet18 = models.resnet18(pretrained=True)

可以将自己的权重下载下来放到同文件夹下,然后再将参数加载网络。

self.model = models.resnet50(pretrained=False)
self.model.load_state_dict(torch.load('./model/resnet50-19c8e357.pth'))

如果中途强行停止下载的话,一定要去对应路径下将权重文件删除干净,要不然可能会报错。

5.3.3:训练特定层

在默认情况下,参数的属性.requires_grad = True

如果正在提取特征并且只想为新初始化的层计算梯度,其他参数不进行改变,就需要通过设置requires_grad = False冻结部分层

PyTorch官方提供的例程:

def set_parameter_requires_grad(model, feature_extracting):if feature_extracting:for param in model.parameters():param.requires_grad = False

修改示例:resnet输出从1000转为4,仅改变最后一层的模型参数,不改变特征提取的模型参数。

先冻结模型参数的梯度,再对模型输出部分的全连接层进行修改,修改后的全连接层的参数就是可计算梯度的。

实现过程:

import torchvision.models as models
# 冻结参数的梯度
feature_extract = True
model = models.resnet18(pretrained=True)
set_parameter_requires_grad(model, feature_extract)
# 修改模型
num_ftrs = model.fc.in_features
model.fc = nn.Linear(in_features=num_ftrs, out_features=4, bias=True)

5.4:模型微调——timm

5.4.1:依赖安装

通过pip

pip install timm

通过源码编译

git clone https://github.com/rwightman/pytorch-image-models
cd pytorch-image-models && pip install -e .

5.4.2:查看预训练模型种类

通过timm.list_models()方法查看timm提供的预训练模型

import timm
avail_pretrained_models = timm.list_models(pretrained=True)
len(avail_pretrained_models)

查看特定模型的所有种类

timm.list_models()传入想查询的模型名称(模糊查询)

案例:查询densenet系列的所有模型

all_densnet_models = timm.list_models("*densenet*")
all_densnet_models"""
['densenet121','densenet121d','densenet161','densenet169','densenet201','densenet264','densenet264d_iabn','densenetblur121d','tv_densenet121']
"""

查看模型的具体参数

访问模型的default_cfg属性

model = timm.create_model('resnet34',num_classes=10,pretrained=True)
model.default_cfg"""
{'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet34-43635321.pth','num_classes': 1000,'input_size': (3, 224, 224),'pool_size': (7, 7),'crop_pct': 0.875,'interpolation': 'bilinear','mean': (0.485, 0.456, 0.406),'std': (0.229, 0.224, 0.225),'first_conv': 'conv1','classifier': 'fc','architecture': 'resnet34'}
"""

5.4.3:使用和修改预训练模型

得到预训练模型后,通过timm.create_model()的方法创建新模型。

通过传入参数pretrained=True,使用预训练模型。

import timm
import torchmodel = timm.create_model('resnet34',pretrained=True)
x = torch.randn(1,3,224,224)
output = model(x)
output.shape"""
torch.Size([1, 1000])
"""

查看某一层模型参数

model = timm.create_model('resnet34',pretrained=True)
list(dict(model.named_children())['conv1'].parameters())"""
[Parameter containing:tensor([[[[-2.9398e-02, -3.6421e-02, -2.8832e-02,  ..., -1.8349e-02,-6.9210e-03,  1.2127e-02],[-3.6199e-02, -6.0810e-02, -5.3891e-02,  ..., -4.2744e-02,-7.3169e-03, -1.1834e-02],...[ 8.4563e-03, -1.7099e-02, -1.2176e-03,  ...,  7.0081e-02,2.9756e-02, -4.1400e-03]]]], requires_grad=True)]
"""

修改模型

model = timm.create_model('resnet34',num_classes=10,pretrained=True)
x = torch.randn(1,3,224,224)
output = model(x)
output.shape"""
torch.Size([1, 10])
"""

改变输入通道数

model = timm.create_model('resnet34',num_classes=10,pretrained=True,in_chans=1)
# 通过添加in_chans=1来改变
x = torch.randn(1,1,224,224)
output = model(x)

5.4.4:模型保存

timm库所创建的模型是torch.model的子类

可以直接使用torch库中内置的模型参数保存和加载的方法

torch.save(model.state_dict(),'./checkpoint/timm_model.pth')
model.load_state_dict(torch.load('./checkpoint/timm_model.pth'))

5.5:半精度训练

GPU的性能主要分为两部分:算力和显存。

前者决定了显卡计算的速度,后者则决定了显卡可以同时放入多少数据用于计算。

PyTorch默认的浮点数存储方式用的是torch.float32

绝大多数场景其实并不需要这么精确,只保留一半的信息也不会影响结果,即使用torch.float16格式。

数位减了一半,因此被称为**“半精度”**。半精度能够减少显存占用,使显卡可以同时加载更多数据进行计算。

5.5.1:半精度训练的设置

使用autocast配置半精度训练,需要在下面三处加以设置:

【1】import package

from torch.cuda.amp import autocast

【2】模型

使用python的装饰器方法,用autocast装饰模型中的forward函数

@autocast()   
def forward(self, x):...return x

【3】训练过程

在将数据输入模型及其之后的部分放入“with autocast():“即可

 for x in train_loader:x = x.cuda()with autocast():output = model(x)...

半精度训练适用于数据本身的size比较大(3D图像、视频等)。

5.6:数据增强——imgaug

最简单的避免过拟合的方法是增加数据

imgaug封装了很多数据增强算法

5.6.1:imgaug简介

imgaug是计算机视觉任务中常用的一个数据增强的包,相比于torchvision.transforms,它提供了更多的数据增强方法。

pip安装:

pip install imgaug

conda安装:

conda config --add channels conda-forge
conda install imgaug

5.6.2:imgaug使用

imgaug仅提供了图像增强的一些方法,但是并未提供图像的IO操作,因此需要使用一些库来对图像进行导入。

建议使用imageio进行读入,如果使用的是opencv进行文件读取的时候,需要进行手动改变通道,将读取的BGR图像转换为RGB图像

PIL.Image进行读取时,因为读取的图片没有shape的属性,所以需要将读取到的img转换为np.array()的形式再进行处理。

【1】单张图片处理

读取图像:

import imageio
import imgaug as ia
%matplotlib inline# 图片的读取
img = imageio.imread("./Lenna.jpg")# 使用Image进行读取
# img = Image.open("./Lenna.jpg")
# image = np.array(img)
# ia.imshow(image)# 可视化图片
ia.imshow(img)

imgaug包含了许多从Augmenter继承的数据增强的操作

数据增强:以Affine为例

from imgaug import augmenters as iaa# 设置随机数种子
ia.seed(4)# 实例化方法
rotate = iaa.Affine(rotate=(-4,45))# 旋转图片
img_aug = rotate(image=img)
ia.imshow(img_aug)

做多种数据增强处理时,需要利用imgaug.augmenters.Sequential()来构造数据增强的pipeline,该方法与torchvison.transforms.Compose()相类似。

标准格式:

iaa.Sequential(children=None, # Augmenter集合random_order=False, # 是否对每个batch使用不同顺序的Augmenter listname=None,deterministic=False,random_state=None)

多种增强:

# 构建处理序列
aug_seq = iaa.Sequential([iaa.Affine(rotate=(-25,25)),iaa.AdditiveGaussianNoise(scale=(10,60)),iaa.Crop(percent=(0,0.2))
])
# 对图片进行处理,image不可以省略,也不能写成images
image_aug = aug_seq(image=img)
ia.imshow(image_aug)

【2】批次图片处理

可以将图形数据按照NHWC的形式,或者由列表组成的HWC的形式对批量的图像进行处理。

主要分为以下两部分,对批次的图片以同一种方式处理和对批次的图片进行分部分处理。

对批次的图片以同一种方式处理

案例1:

images = [img,img,img,img,]
images_aug = rotate(images=images)
ia.imshow(np.hstack(images_aug))

案例2:

aug_seq = iaa.Sequential([iaa.Affine(rotate=(-25, 25)),iaa.AdditiveGaussianNoise(scale=(10, 60)),iaa.Crop(percent=(0, 0.2))
])# 传入时需要指明是images参数
images_aug = aug_seq.augment_images(images = images)
ia.imshow(np.hstack(images_aug))

对批次的图片进行分部分处理

可以通过imgaug.augmenters.Sometimes()对batch中的一部分图片应用一部分Augmenters,剩下的图片应用另外的Augmenters。

标准格式:

iaa.Sometimes(p=0.5,  # 代表划分比例then_list=None,  # Augmenter集合。p概率的图片进行变换的Augmenters。else_list=None,  #1-p概率的图片会被进行变换的Augmenters。注意变换的图片应用的Augmenter只能是then_list或者else_list中的一个。name=None,deterministic=False,random_state=None)

【3】不同大小图片处理

# 构建pipline
seq = iaa.Sequential([iaa.CropAndPad(percent=(-0.2, 0.2), pad_mode="edge"),  # crop and pad imagesiaa.AddToHueAndSaturation((-60, 60)),  # change their coloriaa.ElasticTransformation(alpha=90, sigma=9),  # water-like effectiaa.Cutout()  # replace one squared area within the image by a constant intensity value
], random_order=True)# 加载不同大小的图片
images_different_sizes = [imageio.imread("https://upload.wikimedia.org/wikipedia/commons/e/ed/BRACHYLAGUS_IDAHOENSIS.jpg"),imageio.imread("https://upload.wikimedia.org/wikipedia/commons/c/c9/Southern_swamp_rabbit_baby.jpg"),imageio.imread("https://upload.wikimedia.org/wikipedia/commons/9/9f/Lower_Keys_marsh_rabbit.jpg")
]# 对图片进行增强
images_aug = seq(images=images_different_sizes)# 可视化结果
print("Image 0 (input shape: %s, output shape: %s)" % (images_different_sizes[0].shape, images_aug[0].shape))
ia.imshow(np.hstack([images_different_sizes[0], images_aug[0]]))print("Image 1 (input shape: %s, output shape: %s)" % (images_different_sizes[1].shape, images_aug[1].shape))
ia.imshow(np.hstack([images_different_sizes[1], images_aug[1]]))print("Image 2 (input shape: %s, output shape: %s)" % (images_different_sizes[2].shape, images_aug[2].shape))
ia.imshow(np.hstack([images_different_sizes[2], images_aug[2]]))

5.6.3:imgaug应用

import numpy as np
from imgaug import augmenters as iaa
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms# 构建pipline
tfs = transforms.Compose([iaa.Sequential([iaa.flip.Fliplr(p=0.5),iaa.flip.Flipud(p=0.5),iaa.GaussianBlur(sigma=(0.0, 0.1)),iaa.MultiplyBrightness(mul=(0.65, 1.35)),]).augment_image,# 不要忘记了使用ToTensor()transforms.ToTensor()
])# 自定义数据集
class CustomDataset(Dataset):def __init__(self, n_images, n_classes, transform=None):# 图片的读取,建议使用imageioself.images = np.random.randint(0, 255,(n_images, 224, 224, 3),dtype=np.uint8)self.targets = np.random.randn(n_images, n_classes)self.transform = transformdef __getitem__(self, item):image = self.images[item]target = self.targets[item]if self.transform:image = self.transform(image)return image, targetdef __len__(self):return len(self.images)def worker_init_fn(worker_id):imgaug.seed(np.random.get_state()[1][0] + worker_id)custom_ds = CustomDataset(n_images=50, n_classes=10, transform=tfs)
custom_dl = DataLoader(custom_ds, batch_size=64,num_workers=4, pin_memory=True, worker_init_fn=worker_init_fn)

使用Linux远程服务器时,可使用不同的num_workers的数量。worker_init_fn()函数的作用:保证了使用的数据增强在num_workers>0时是对数据的增强是随机的。

其他数据增强库:Albumentations,Augmentor。

5.7:使用argparse调参

作用:解析输入的命令行参数再传入模型的超参数中。

5.7.1:简介

argsparse是python的命令行解析的标准模块,内置于python,不需要安装。

argparse的作用是将命令行传入的其他参数进行解析、保存和使用。

5.7.2:使用

三个步骤:

  • 创建ArgumentParser()对象
  • 调用add_argument()方法添加参数
  • 使用parse_args()解析参数
import argparse# 创建ArgumentParser()对象
parser = argparse.ArgumentParser()# 添加参数
parser.add_argument('-o', '--output', action='store_true', help="shows output")
# action = `store_true` 会将output参数记录为Trueparser.add_argument('--lr', type=float, default=3e-5, help='select the learning rate, default=1e-3') 
# type 规定了参数的格式
# default 规定了默认值parser.add_argument('--batch_size', type=int, required=True, help='input batch size')  
# required=True,表示batch-size是必选# 使用parse_args()解析函数
args = parser.parse_args()if args.output:print("This is some output")print(f"learning rate:{args.lr} ")

在命令行使用python demo.py --lr 3e-4 --batch_size 32,可以看到以下的输出

This is some output
learning rate: 3e-4

argparse的参数主要可以分为可选参数和必选参数。

可选参数在未输入的情况下会设置为默认值。

必选参数在当给参数设置required =True后,就必须传入该参数,否则就会报错。

【输入参数的时候不使用–】会严格按照参数位置进行解析

import argparse# 位置参数
parser = argparse.ArgumentParser()parser.add_argument('name')
parser.add_argument('age')args = parser.parse_args()print(f'{args.name} is {args.age} years old')

终端输入和输出的结果

$ positional_arg.py Peter 23
Peter is 23 years old

5.7.3:高效修改超参数

将有关超参数的操作写在config.py,然后在train.py或者其他文件导入。

【config.py】

import argparse  def get_options(parser=argparse.ArgumentParser()):  parser.add_argument('--workers', type=int, default=0,  help='number of data loading workers, you had better put it '  '4 times of your gpu')  parser.add_argument('--batch_size', type=int, default=4, help='input batch size, default=64')  parser.add_argument('--niter', type=int, default=10, help='number of epochs to train for, default=10')  parser.add_argument('--lr', type=float, default=3e-5, help='select the learning rate, default=1e-3')  parser.add_argument('--seed', type=int, default=118, help="random seed")  parser.add_argument('--cuda', action='store_true', default=True, help='enables cuda')  parser.add_argument('--checkpoint_path',type=str,default='',  help='Path to load a previous trained model if not empty (default empty)')  parser.add_argument('--output',action='store_true',default=True,help="shows output")  opt = parser.parse_args()  if opt.output:  print(f'num_workers: {opt.workers}')  print(f'batch_size: {opt.batch_size}')  print(f'epochs (niters) : {opt.niter}')  print(f'learning rate : {opt.lr}')  print(f'manual_seed: {opt.seed}')  print(f'cuda enable: {opt.cuda}')  print(f'checkpoint_path: {opt.checkpoint_path}')  return opt  if __name__ == '__main__':  opt = get_options()

终端:

$ python config.pynum_workers: 0
batch_size: 4
epochs (niters) : 10
learning rate : 3e-05
manual_seed: 118
cuda enable: True
checkpoint_path:

【train.py】

import configopt = config.get_options()manual_seed = opt.seed
num_workers = opt.workers
batch_size = opt.batch_size
lr = opt.lr
niters = opt.niters
checkpoint_path = opt.checkpoint_path# 随机数的设置,保证复现结果
def set_seed(seed):torch.manual_seed(seed)torch.cuda.manual_seed_all(seed)random.seed(seed)np.random.seed(seed)torch.backends.cudnn.benchmark = Falsetorch.backends.cudnn.deterministic = True...if __name__ == '__main__':set_seed(manual_seed)for epoch in range(niters):train(model,lr,batch_size,num_workers,checkpoint_path)val(model,lr,batch_size,num_workers,checkpoint_path)

argparse提供了一种新的更加便捷的方式,而在一些大型的深度学习库中人们也会使用json、dict、yaml等文件格式去保存超参数进行训练。


http://www.ppmy.cn/ops/34543.html

相关文章

Case中default的综合结果

在使用case语句时,不完备的case语句会导致Vivado综合时推断出锁存器。下面通过实例来详细看看各种情况下的综合结果: 1.完备的case语句 下述的verilog对应的电路结构是一个8选一的多路复用器: module case_test(input [2:0]sel,input data…

公众号图片尺寸怎么调整?图片在线处理的方法介绍

平时我们接触到的图片文件有非常多的格式,而且收到的尺寸各不相同,这种时候就要我们修改图片尺寸大小了,这样做首先可以为我们节省存储空间,还可以让图片的加载速度变快,分享出去的图片也可以更快进行查看,…

53. 【Android教程】Socket 网络接口

Socket 网络接口 大家在学习计算机网络的时候一定学习过 TCP/IP 协议以及最经典的 OSI 七层结构,简单的回忆一下这 7 层结构: 从下到上依次是: 物理层数据链路层互联层网络层会话层表示层应用层 TCP/IP 协议对这 7 层了做一点精简&#xff…

Netty的一个入门小程序

文章目录 1. 服务端代码2. 客户端代码3. 执行流程分析 1. 服务端代码 Slf4j public class Server {public static void main(String[] args) {//启动器,负责组装netty组件,启动服务器new ServerBootstrap()//BoosEventLoop,WorkerEventLoop&…

Go 语言 ORM 框架之 xorm

1、xorm 1.1、xorm 简介 xorm 是一个简单而强大的Go语言ORM库. 通过它可以使数据库操作非常简便。 特性 支持 struct 和数据库表之间的灵活映射,并支持自动同步事务支持同时支持原始SQL语句和ORM操作的混合执行使用连写来简化调用支持使用ID, In, Where, Limit,…

用 Go map 要注意这个细节,避免依赖他!

有的小伙伴没留意过 Go map 输出、遍历顺序,以为它是稳定的有序的,会在业务程序中直接依赖这个结果集顺序,结果栽了个大跟头,吃了线上 BUG。 有的小伙伴知道是无序的,但却不知道为什么,有的却理解错误? 今…

23种设计模式

一、创建型模式: 1.工厂方法模式:定义一个用于创建对象的接口,让子类决定实例化哪个类。 2.抽象工厂模式:提供一个创建一系列相关或相互依赖对象的接口,而无需指定具体类。 3.单例模式:确保一个类只有一个…

线下线上陪玩APP小程序H5搭建设计-源码交付,支持二开!

一、电竞陪玩系统APP的概念 电竞陪玩系统APP是一种专门为电子竞技玩家提供服务的平台。通过这个平台,玩家可以找到专业的电竞陪玩者,他们可以帮助玩家提升游戏技能,提供游戏策略建议,甚至陪伴玩家一起进行游戏。这种服务不仅可以提…