调试和优化大型深度学习模型 - 4 混合精度训练中的关键组件 autocast 和 GradScaler

news/2024/9/25 22:25:53/

autocast__GradScaler_0">调试和优化大型深度学习模型 - 4 混合精度训练中的关键组件 autocastGradScaler

flyfish

PyTorch 版本 2.4.0

在混合精度训练中,autocastGradScaler 通常是一起使用的。autocast 提供了操作的半精度计算,而 GradScaler 通过缩放损失来防止可能发生的梯度下溢。结合使用它们,可以同时提高计算效率和数值稳定性

autocast___10">1. autocast - 自动混合精度

autocast 是 PyTorch 提供的一个上下文管理器,用于在模型的前向传播过程中自动选择合适的浮点数精度(FP16 或 FP32)。通过使用 autocast,你可以让模型在计算过程中自动将部分操作转换为半精度(FP16),从而加快计算速度并减少显存占用,同时保持数值精度较低的操作在 FP16 下执行。

使用 autocast 在混合精度训练中进行模型的前向传播例子

import torch
import torch.nn as nn
import torch.optim as optim
from torch.amp import autocast# 检查是否有可用的 GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 定义一个简单的神经网络
class SimpleModel(nn.Module):def __init__(self):super(SimpleModel, self).__init__()self.fc1 = nn.Linear(100, 50)self.fc2 = nn.Linear(50, 10)def forward(self, x):x = torch.relu(self.fc1(x))x = self.fc2(x)return x# 实例化模型和优化器
model = SimpleModel().to(device)
optimizer = optim.SGD(model.parameters(), lr=0.01)
criterion = nn.MSELoss()# 创建输入数据和目标数据
inputs = torch.randn(10, 100).to(device)
targets = torch.randn(10, 10).to(device)# 使用 autocast 进行前向传播
with autocast(device_type='cuda'):outputs = model(inputs)loss = criterion(outputs, targets)# 继续进行后向传播和优化
optimizer.zero_grad()
loss.backward()
optimizer.step()print("autocast 示例运行成功,损失值:", loss.item())

输出
autocast 示例运行成功,损失值: 0.8400946855545044

GradScaler___59">2. GradScaler - 梯度缩放器

GradScaler 是 PyTorch 提供的用于混合精度训练中的梯度缩放工具。由于在 FP16 下计算梯度时可能会遇到数值下溢的问题(即梯度值过小,导致在反向传播时梯度被削减为 0),GradScaler 通过在反向传播之前将损失值缩放一个大数,从而避免梯度下溢。之后,GradScaler 会反过来缩放梯度,确保它们回到正常范围。
结合 GradScaler 使用混合精度训练,特别是梯度缩放以防止梯度下溢例子

import torch
import torch.nn as nn
import torch.optim as optim
from torch.amp import autocast, GradScaler# 检查是否有可用的 GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 定义一个简单的神经网络
class SimpleModel(nn.Module):def __init__(self):super(SimpleModel, self).__init__()self.fc1 = nn.Linear(100, 50)self.fc2 = nn.Linear(50, 10)def forward(self, x):x = torch.relu(self.fc1(x))x = self.fc2(x)return x# 实例化模型、优化器和 GradScaler
model = SimpleModel().to(device)
optimizer = optim.SGD(model.parameters(), lr=0.01)
criterion = nn.MSELoss()
scaler = GradScaler()# 创建输入数据和目标数据
inputs = torch.randn(10, 100).to(device)
targets = torch.randn(10, 10).to(device)# 使用 autocast 进行前向传播,并结合 GradScaler 进行梯度缩放
with autocast(device_type='cuda'):outputs = model(inputs)loss = criterion(outputs, targets)# 梯度缩放与反向传播
optimizer.zero_grad()
scaler.scale(loss).backward()# 更新模型参数
scaler.step(optimizer)
scaler.update()print("GradScaler 示例运行成功,损失值:", loss.item())

输出
GradScaler 示例运行成功,损失值: 1.0215777158737183

要注意什么

梯度下溢:
混合精度训练中的一个常见问题是梯度下溢。由于 FP16 精度较低,在反向传播过程中,梯度值可能会变得非常小甚至为零。使用 GradScaler 可以有效地防止这个问题。

损失函数的稳定性:
一些损失函数在 FP16 下可能表现出数值不稳定的情况。在这种情况下,考虑使用 FP32 来计算损失或将关键操作保持在 FP32。

模型中的 BatchNorm 和其他层:
BatchNorm 和其他类似的层可能对精度敏感。PyTorch 通常会在 autocast 上下文中将这些层保持在 FP32,以防止数值不稳定。如果你发现训练不稳定,检查这些层的数值精度可能会有帮助。

硬件支持:
混合精度训练需要硬件的支持,例如 NVIDIA 的 Tensor Cores(在 Volta 及更高版本的 GPU 上提供)。确保你的 GPU 支持 FP16 运算以获得性能提升。

检查数值稳定性:
即使使用了 autocastGradScaler,也要监控训练过程中的数值稳定性,特别是损失值和梯度的变化。如果遇到不稳定情况,可以调整 GradScaler 的初始化参数或关闭某些不适合 FP16 的操作。

性能调优:
虽然混合精度训练通常会提高性能,但实际效果取决于模型的结构和硬件。可以通过实验调整 autocastGradScaler 的使用,找到最优的配置。


http://www.ppmy.cn/news/1511976.html

相关文章

全面解析Gerapy分布式部署:从环境搭建到定时任务,避开Crawlab的坑

Gerapy分布式部署 搭建远程服务器的环境 装好带docker服务的系统 Docker:容器可生成镜像,也可拉去镜像生成容器 示例:将一个环境打包上传到云端(远程服务器),其他8个服务器需要这个环境直接向云端拉取镜像生成容器,进而使用该环境,比如有MYS…

ActiveMQ、RabbitMQ、Kafka、RocketMQ在消息回溯、消息堆积+持久化、消息追踪、消息过滤的区别

ActiveMQ、RabbitMQ、Kafka、RocketMQ在消息回溯、消息堆积持久化、消息追踪、消息过滤等方面各有其独特的特点和优势。以下是这四个方面的详细比较: 1. 消息回溯 ActiveMQ:支持消息回溯功能。ActiveMQ可以将消息持久化到磁盘上,因此当需要…

运维学习————Redis在Linux(Centos7)单机部署和集群部署

目录 一、单机部署 1、软件准备 2、安装配置 3、启动Redis 二、Redis集群 2.1、主从模式 2.1.1、作用 2.1.2、规划图 2.1.3、具体配置 准备工作 主从配置 启动测试 2.1.4、主从复制原理 主从全量复制 主从增量同步(slave重启或后期数据变化) 2.1.5、缺点 2.2、哨兵…

PostgreSQL-04-入门篇-连接多张表

文章目录 1. 连接设置样例表PostgreSQL 左连接PostgreSQL 右连接PostgreSQL 全外连接 2. 表别名简介表别名的实际应用1) 对长表名使用表别名,使查询更具可读性2) 在连接子句中使用表别名3) 在自连接中使用表别名 3. INNER JOIN 内连接简介PostgreSQL INNER JOIN 示例…

day04-Maven入门

Maven 课程内容 初识MavenMaven概述 Maven模型介绍Maven仓库介绍Maven安装与配置IDEA集成Maven依赖管理 01. Maven课程介绍 1.1 课程安排 学习完前端Web开发技术后,我们即将开始学习后端Web开发技术。做为一名Java开发工程师,后端Web开发技术是我们学…

OpenLayers 使用高德地图并绘制一些线,并用Android原生触发

这是一份OpenLayers使用高德地图并绘制一些线代码&#xff0c;这高德来源好像不太正规建议自己去开发者平台逛逛。代码都有注释我就不过多介绍了。 <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8" /><meta name&…

libsvm数据格式及制作

libsvm的数据格式&#xff1a; [label][index1]:[value1] [index2]:[value2] … 其中&#xff0c;label 目标值&#xff0c;就是说class属于哪一类&#xff0c;就是你要分的类别&#xff0c;通常是一些证书&#xff1b; index 是有顺序的索引&#xff0c;通常是连续的整数。就是…

Mysql查询日志

Mysql查询日志 Mysql查询日志默认是关闭状态的。 mysql> show variables like %general_log%; --------------------------------------- | Variable_name | Value | --------------------------------------- | general_log | OFF …