【Pytorch】理解自动混合精度训练

news/2024/10/22 11:04:51/

【Pytorch】理解自动混合精度训练

  • 混合精度概述
  • 实验对比

  更大的深度学习模型需要更多的计算能力和内存资源。一些新技术的提出,可以更快地训练深度神经网络。我们可以使用 FP16(半精度浮点数格式)来代替 FP32(全精度浮点数格式),研究人员发现串联使用它们是更好的选择。有的 GPU(例如 Paperspace 提供的 Ampere GPU)甚至可以利用较低级别的精度,例如 INT8。

  混合精度允许半精度训练,同时仍保留大部分单精度网络精度。术语“混合精度技术”是指该方法同时使用单精度和半精度表示。

  在使用 PyTorch 进行自动混合精度 (Amp) 训练的概述中,我们演示了该技术的工作原理,逐步介绍使用 Amp 的过程,并通过代码讨论 Amp 技术的应用。

混合精度概述

  在深度学习的世界里,使用 FP16 进行计算不仅能显著提升性能,还能节省内存。然而,这种方法也带来了两个主要问题:精度溢出和舍入误差。这两个问题是深度学习中 FP16 计算的关键挑战。

   精度溢出(Precision Overflow)

  在 FP16 格式下,由于位宽较小,可表示的数值范围远小于 FP32 或 FP64。这容易导致数值过大或过小而无法在 FP16 的表示范围内精确表示。在深度学习中,这可能引起梯度消失或梯度爆炸,因为一些小的梯度值可能变成零(下溢),而一些大的梯度值可能变得无限大(上溢)。这种溢出问题会严重影响模型训练的稳定性和最终性能。

  舍入误差(Rounding Error)

  FP16 由于其16位的表示限制,相比于 FP32 或 FP64,舍入误差更加明显。在深度学习中,每次计算的舍入误差会累积,尤其是在多层和复杂运算中。这可能导致模型输出与使用更高精度计算时存在显著差异。对于那些对精确度要求极高的应用(比如金融或医疗领域),这种误差可能造成不可接受的后果。

  为了缓解这些问题,混合精度训练方法在关键部分(如权重更新)使用 FP32 来保持精度,而在其他操作(如前向传播)中使用 FP16 来提高效率。混合精度训练中,我特别注意到了权重备份(Weight Backup)、损失放大(Loss Scaling)、精度累加(Precision Accumulated)这三种技术的重要性。

  权重备份(Weight Backup):

  在混合精度训练中,为了确保数值稳定性,模型的权重通常会在 FP16 和 FP32 两种格式下同时维护。权重备份是指保留 FP32 格式的权重副本,这样即使在大部分使用 FP16 格式的计算过程中出现数值不稳定现象,我们仍然能依靠 FP32 权重副本保持稳定和精确。这对于更新模型参数时的准确性至关重要。

  损失放大(Loss Scaling):

  在混合精度训练中,由于 FP16 的表示范围限制,梯度值可能太小而无法在 FP16 中准确表示,导致有效梯度变为零。损失放大是通过在计算梯度前将损失函数的值乘以一个较大的常数(放大因子),从而放大梯度值,使其在 FP16 范围内可表示且非零。在反向传播后,再将放大的梯度除以相同的放大因子,恢复原始比例,这样可以有效减少梯度下溢问题。

  精度累加(Precision Accumulation):

  精度累加是指在权重更新过程中,即使梯度计算是在 FP16 下完成的,但权重更新则在 FP32 精度下进行。这有助于减少舍入误差和累积误差,尤其在训练过程中涉及大量累积操作时。由于 FP32 提供更高的数值精度和更大的表示范围,可以更准确地累积小梯度值,避免更新权重时的数值不稳定性。

  综上所述,通过将这些技术相结合,混合精度训练能够有效地利用 FP16 带来的性能优势,同时最大限度地减少精度损失和计算不稳定性。

实验对比

  为了进一步验证这些技术的有效性,我设计了两个实验对比使用混合精度和传统的 FP32 两种方式进行训练。以下是我在这两个实验中使用的代码片段:

FP16与FP32混合训练代码:

import torch
from tensorboardX import SummaryWriter
from torch import optim, nn
import timefrom torch.cuda.amp import GradScaler, autocastclass Model(torch.nn.Module):def __init__(self):super(Model, self).__init__()self.linears = nn.Sequential(nn.Linear(2, 20000),nn.Linear(20000, 20000),# nn.Dropout(0.1),nn.Linear(20000, 200),# nn.LayerNorm(20),nn.Linear(200, 20),# nn.LayerNorm(20),nn.Linear(20, 1),)def forward(self, x):_ = self.linears(x)return _lr = 0.0001
iteration = 1000x1 = torch.arange(-1000, 1000).float().to('cuda')
x2 = torch.arange(0, 2000).float().to('cuda')
x = torch.cat((x1.unsqueeze(1), x2.unsqueeze(1)), dim=1)
y = (2*x1 - x2 + 1).to('cuda')model = Model().to('cuda')
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=0.01)
loss_function = torch.nn.MSELoss()scaler = GradScaler()start_time = time.time()
writer = SummaryWriter(comment='_FP16')for iter in range(iteration):with autocast():y_pred = model(x)loss = loss_function(y, y_pred.squeeze())scaler.scale(loss).backward()scaler.step(optimizer)scaler.update()writer.add_scalar('loss', loss, iter)optimizer.zero_grad()if iter % 100 == 0:# 获取 GPU 的内存使用情况print("GPU Memory Allocated:", torch.cuda.memory_allocated())print("GPU Memory Cached:   ", torch.cuda.memory_reserved())print("Time: ", time.time() - start_time)
torch.save(model.state_dict(), 'model_state_dict_fp16.pth')

FP32训练代码:

import torch
from tensorboardX import SummaryWriter
from torch import optim, nn
import timefrom torch.cuda.amp import GradScaler, autocastclass Model(torch.nn.Module):def __init__(self):super(Model, self).__init__()self.linears = nn.Sequential(nn.Linear(2, 20000),nn.Linear(20000, 20000),# nn.Dropout(0.1),nn.Linear(20000, 200),# nn.LayerNorm(20),nn.Linear(200, 20),# nn.LayerNorm(20),nn.Linear(20, 1),)def forward(self, x):_ = self.linears(x)return _lr = 0.0001
iteration = 1000x1 = torch.arange(-1000, 1000).float().to('cuda')
x2 = torch.arange(0, 2000).float().to('cuda')
x = torch.cat((x1.unsqueeze(1), x2.unsqueeze(1)), dim=1)
y = (2*x1 - x2 + 1).to('cuda')model = Model().to('cuda')
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=0.01)
loss_function = torch.nn.MSELoss()scaler = GradScaler()start_time = time.time()
writer = SummaryWriter(comment='_FP32')for iter in range(iteration):y_pred = model(x)loss = loss_function(y, y_pred.squeeze())loss.backward()optimizer.step()writer.add_scalar('loss', loss, iter)optimizer.zero_grad()if iter % 100 == 0:# 获取 GPU 的内存使用情况print("GPU Memory Allocated:", torch.cuda.memory_allocated())print("GPU Memory Cached:   ", torch.cuda.memory_reserved())print("Time: ", time.time() - start_time)
torch.save(model.state_dict(), 'model_state_dict_fp32.pth')

  最终两者的效果如下:

实验名占用GPU消耗时间
FP16与FP32混合训练4867271680 bytes78.73 s
FP32训练4867274752 bytes140.18 s

  实验分析如下:

  1、在内存占用方面,两种训练方法几乎相同。这可能是因为模型结构和数据集大小相同,所以内存占用没有显著差异。然而,通常情况下,FP16训练应该占用更少的内存,因为它使用的是半精度浮点数。

  2、在训练时间方面,混合精度训练明显快于纯FP32训练。这是因为FP16训练可以加快计算速度并降低内存需求,从而允许模型更快地运行。混合精度训练结合了FP16的高效率和FP32的数值稳定性,提供了一个平衡的解决方案。

在这里插入图片描述

  在提供的损失图中,我们可以看到蓝色曲线代表使用全精度(FP32)训练的模型损失,而橙色曲线代表使用混合精度(FP16与FP32)训练的模型损失。以下是对两种训练方法损失曲线的进一步分析:

  · 在下降到一定程度后,两条曲线都达到平稳状态,这表明模型已基本收敛。在这个平稳阶段,损失变化不大,说明模型在训练集上的表现已经稳定。

  · 没有明显的过拟合迹象,因为损失曲线没有再次上升的趋势。

  · 混合精度训练似乎在时间效率略优于全精度训练,尽管最终损失值的差异不大,但在追求快速迭代和高效训练的情况下,选择混合精度训练会更有优势;而在收敛速度上,混合精度训练在140epoch时收敛完毕,而全精度训练早在100epoch就收敛完,说明全精度训练收敛较于混合精度训练,在step层面更快,而它在时间层面更慢。


http://www.ppmy.cn/news/1259757.html

相关文章

实战中使用的策略模式,使用@ConditionalOnProperty实现根据环境注册不同的bean

场景复现 举个例子,针对不同的设备的内存的不同加载一些资源的时候需要采取不同的策略,比如,在内存比较大的设备,可以一次性加载,繁殖需要使用懒加载,这个时候我们就可以采用配置文件配置中心去控制了 Cond…

吴恩达《机器学习》11-1-11-2:首先要做什么、误差分析

一、首先要做什么 选择特征向量的关键决策 以垃圾邮件分类器算法为例,首先需要决定如何选择和表达特征向量 𝑥。视频提到的一个示例是构建一个由 100 个最常出现在垃圾邮件中的词构成的列表,根据这些词是否在邮件中出现来创建特征向量&…

科技论文中的Assumption、Remark、Property、Lemma、Theorem、Proof含义

一、背景 学控制、数学、自动化专业的学生在阅读论文时,经常会看到Assumption、Remark、Property、Lemma、Theorem、Proof等单词,对于初学者可能不太清楚他们之间的区别,因此这里做一下详细的说明。 以机器人领域的论文为例。 论文题目&…

【从0配置JAVA项目相关环境1】jdk + VSCode运行java + mysql + Navicat + 数据库本地化 + 启动java项目

从0配置JAVA项目相关环境 写在最前面一、安装Java的jdk环境1. 下载jdk2. 配置jdk3. 配置环境变量 二、在vscode中配置java运行环境1. 下载VSCode2. 下载并运行「Java Extension Pack」 三、安装mysql1.官网下载MySQL2.开始安装如果没有跳过安装成功 3.配置MySQL Server4.环境变…

STM32串口接收数据包(自定义帧头帧尾)

1、基本概述 本实验基于stm32c8t6单片机,串口作为基础且重要的外设,具有广泛的应用。本文主要理解串口数据包的发送与接收是如何实现的,重要的是理解程序的实现思路。 2、关键程序 定义好需要用到的变量: uint8_t rxd_buf[4];//…

1.1 计算机和编程语言

计算机与编程语言的用处 计算机与大家的生活息息相关,例如银行的ATM机就是计算机、日常使用的手机等。大家大部分情况都是使用现有的软件,只有在特定场景、特定需求的环境下才会编写软件 课程目的 计算机是怎么工作的计算机擅长干什么,计算…

[传智杯 #4 初赛] 萝卜数据库

题目描述 花栗鼠很喜欢偷吃生产队的大萝卜,因此花栗鼠科技大学正在研究一种新型的数据库,叫做萝卜数据库。 具体来说,它支持 k(1≤k≤100) 个字段,每个字段名都是整数,里面存储的数值也都是整数。 现在你支持如下操…

*p++和(*p)++的区别

*p和(*p)的区别 *和是同优先级操作符,且都是从右至左结合的 ∗ * ∗p:取p所指单元的值,p指向下一单元,即p自加1,然后p指向下一个地址。和 (p)意思一样 (*p):()的优先级比和都高,所以作用在()内…