25.梯度消失和梯度爆炸

news/2024/10/19 23:37:28/

深度学习中的梯度消失与梯度爆炸:定义、原因、解决办法与残差网络

一、引言

在深度学习的训练过程中,梯度消失(Gradient Vanishing)和梯度爆炸(Gradient Exploding)是两个常见且棘手的问题。它们严重阻碍了深层神经网络的训练效率和效果。本文将深入探讨这两个问题的定义、原因、解决办法,并介绍残差网络(ResNet)如何解决这些问题。

二、梯度消失与梯度爆炸的定义

梯度消失

梯度消失指的是在训练深层神经网络时,由于链式法则的连乘效应,当网络层数过深时,梯度在反向传播过程中会逐渐减小到接近于0,导致深层网络的权重无法得到有效的更新。

梯度爆炸

梯度爆炸则相反,指的是在训练深层神经网络时,梯度在反向传播过程中逐渐增大,甚至以指数级速度增长,导致权重更新过大,破坏网络的稳定性。

三、梯度消失与梯度爆炸的原因

链式法则

在反向传播过程中,梯度是通过链式法则逐层传递的。如果网络层数过深,且激活函数的梯度小于1(如Sigmoid函数),那么在多层连续相乘后,梯度会逐渐减小到接近于0,导致梯度消失;而如果梯度大于1,则会导致梯度爆炸。

初始化权重

网络权重的初始化方式也会影响梯度的传播。如果初始权重过大,可能导致梯度在反向传播过程中迅速增大,引起梯度爆炸;如果初始权重过小,则可能导致梯度在传播过程中逐渐减小,引起梯度消失。

四、梯度消失与梯度爆炸的解决办法

1.预训练与微调(Pre-training and Fine-tuning):早期的一种方法,先在一个大型数据集上进行预训练,然后在特定任务上进行微调。这种方法可以减轻梯度消失和爆炸的问题,但现在已经较少使用。

2.梯度裁剪

梯度裁剪是一种直接控制梯度大小的方法。在反向传播过程中,如果梯度的范数超过某个阈值,就将其截断为阈值大小。这样可以有效防止梯度爆炸。

3.使用ReLU激活函数

ReLU(Rectified Linear Unit)激活函数在输入大于0时梯度为1,不会出现梯度消失的问题;而在输入小于0时梯度为0,有助于稀疏化网络。因此,使用ReLU激活函数可以有效缓解梯度消失和梯度爆炸的问题。

4.改进版的ReLU激活函数:为了解决ReLU的缺点,研究者提出了多种改进版的ReLU函数,如Leaky ReLU、Parametric ReLU(PReLU)、Exponential Linear Unit(ELU)等。

5.Batch Normalization

Batch Normalization是一种有效的正则化方法,它通过规范化每一层的输入来加速网络训练。在训练过程中,Batch Normalization会对每一层的输入进行标准化处理,使其具有均值为0、方差为1的分布。这样可以减小梯度对初始权重的依赖,从而缓解梯度消失和梯度爆炸的问题。

6.残差网络(ResNet)

残差网络通过引入残差连接(shortcut connections)来解决梯度消失和梯度爆炸的问题。残差连接允许梯度在反向传播时绕过某些层直接传播到较浅的层,从而有效避免了梯度消失的问题。同时,由于残差连接的存在,网络在训练时可以更容易地学习到恒等映射(identity mapping),这有助于保持网络的稳定性并防止梯度爆炸。

 

五、残差网络(ResNet)的实现

基于残差网络(ResNet)的实现,我们可以进一步探讨其结构、特点以及在实际应用中的优势。以下是对ResNet实现的详细解析:

1. 残差块(Residual Block)

残差块是ResNet的核心组件,它解决了随着网络深度增加出现的性能下降(也称为退化问题)的问题。残差块的设计基于恒等映射(identity mapping)的思想,允许网络在必要时跳过一些层,从而更直接地传播梯度。

残差块的基本结构如下:

  • 包含两个或多个卷积层(以及可能的批量归一化层和激活函数层)。
  • 引入了一个跨层的连接(即shortcut或skip connection),将输入直接连接到输出。

这样的结构可以表示为:

H(x)=F(x)+x

其中,x 是输入,F(x) 是残差函数(即卷积层等结构所学习的映射),H(x) 是最终的输出。

2. 残差网络的构建

ResNet由多个残差块堆叠而成,形成一个深层的神经网络结构。根据具体的任务和网络规模,可以设计不同深度和宽度的ResNet。

在构建ResNet时,需要考虑以下几点:

  • 深度:通常,增加网络深度可以提高性能,但也会增加计算量和过拟合的风险。因此,需要根据任务和数据集的大小选择合适的深度。
  • 宽度:每个残差块的宽度(即卷积层的通道数)也会影响网络的性能。较宽的残差块可以提取更多的特征,但也会增加计算量。
  • 残差块的类型:根据残差块中卷积层的数量和连接方式,可以设计不同类型的残差块,如基本的残差块(包含两个卷积层)和瓶颈残差块(包含三个卷积层,其中第一个和最后一个卷积层的通道数较少,以减少计算量)。

3. 实现细节

在实现ResNet时,需要注意以下细节:

  • 初始化:使用合适的权重初始化方法,如He初始化,可以加速训练并提高模型的性能。
  • 批量归一化:在每个卷积层后添加批量归一化层,可以加速训练并缓解过拟合问题。
  • 激活函数:使用ReLU或类似的激活函数,以增加模型的非线性表达能力。
  • 下采样:在需要减小特征图尺寸时,可以使用步长为2的卷积层或池化层进行下采样。同时,为了确保残差连接能够匹配输入和输出的尺寸,可以在shortcut连接中添加一个额外的卷积层或池化层进行下采样。

4. 应用与优势

ResNet在多个领域都取得了显著的性能提升,特别是在图像分类、目标检测等任务中。其优势主要体现在以下几个方面:

  • 解决了深度神经网络中的退化问题,使得训练更深层的网络成为可能。
  • 通过引入残差连接,缓解了梯度消失和梯度爆炸的问题,提高了模型的训练效率和稳定性。
  • 具有较强的特征提取能力,可以学习到更丰富的层次化特征表示。
  • 具有良好的泛化能力,可以在不同的数据集和任务上取得较好的性能。

总之,ResNet通过引入残差连接的思想,成功解决了深度神经网络中的退化问题,并在多个领域取得了显著的性能提升。其实现细节和应用优势也为我们设计更优秀的深度学习模型提供了有益的参考。

import torch  
import torch.nn as nn  class BasicBlock(nn.Module):  expansion = 1  def __init__(self, in_channels, out_channels, stride=1, downsample=None):  super(BasicBlock, self).__init__()  # 第一个卷积层,不改变通道数  self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)  self.bn1 = nn.BatchNorm2d(out_channels)  self.relu = nn.ReLU(inplace=True)  # 第二个卷积层,不改变通道数和步长  self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)  self.bn2 = nn.BatchNorm2d(out_channels)  # 如果需要下采样,则使用1x1卷积改变通道数并降低空间分辨率  self.downsample = downsample  def forward(self, x):  residual = x  # 经过两个卷积层  out = self.conv1(x)  out = self.bn1(out)  out = self.relu(out)  out = self.conv2(out)  out = self.bn2(out)  # 如果需要进行下采样,则对输入x进行同样的操作  if self.downsample is not None:  residual = self.downsample(x)  # 将残差连接添加到输出上  out += residual  out = self.relu(out)  return out  class ResNet(nn.Module):  def __init__(self, block, layers, num_classes=10):  super(ResNet, self).__init__()  # 输入为3通道的图像,大小为224x224  self.in_channels = 64  # 初始的卷积层  self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)  self.bn1 = nn.BatchNorm2d(64)  self.relu = nn.ReLU(inplace=True)  self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)  # 构建残差块  self.layer1 = self._make_layer(block, 64, layers[0])  self.layer2 = self._make_layer(block, 128, layers[1], stride=2)  self.layer3 = self._make_layer(block, 256, layers[2], stride=2)  self.layer4 = self._make_layer(block, 512, layers[3], stride=2)  # 全连接层进行分类  self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  self.fc = nn.Linear(512 * block.expansion, num_classes)  def _make_layer(self, block, out_channels, blocks, stride=1):  downsample = None  if stride != 1 or self.in_channels != out_channels * block.expansion:  downsample = nn.Sequential(  nn.Conv2d(self.in_channels, out_channels * block.expansion, kernel_size=1, stride=stride, bias=False),  nn.BatchNorm2d(out_channels * block.expansion)  )  layers = []  layers.append(block(self.in_channels, out_channels, stride, downsample))  self.in_channels = out_channels * block.expansion  for _ in range(1, blocks):  layers.append(block(self.in_channels, out_channels))  return nn.Sequential(*layers)  def forward(self, x):  out = self.conv1(x)  out = self.bn1(out)  out = self.relu(out)  out = self.maxpool(out)  # 传递输入到各个残差层  out = self.layer1(out)  out = self.layer2(out)  out = self.layer3(out)  out = self.layer4(out)  # 对输出进行全局平均池化,展平  out = self.avgpool(out)  out = torch.flatten(out, 1)  # 全连接层进行分类  out = self.fc(out)  return out  # 示例:定义一个ResNet18  
def resnet18(num_classes=1000):  return ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes)  # 实例化ResNet18模型  
model = resnet18(num_classes=10)  # 假设有10个类别  # 打印模型结构  
print(model)  # 如果你有数据的话,可以继续编写代码进行训练  
# 例如,加载数据集、定义损失函数、优化器、训练循环等  # 示例:定义损失函数和优化器(这里只是示例,你需要根据实际情况设置)  
# criterion = nn.CrossEntropyLoss()  
# optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)  # 注意:上面的代码只是一个示例,并没有实际的数据加载和训练过程。  
# 在实际使用中,你需要添加数据加载、训练循环、验证等步骤来完整实现ResNet的训练。

以上代码定义了一个简单的ResNet模型,并给出了一个实例化ResNet18的示例。ResNet18包含4个残差层,每个层包含2个BasicBlock。你可以根据实际需求调整层数和每层的Block数量来构建不同深度的ResNet模型。同时,你还需要定义损失函数和优化器,并编写数据加载和训练循环的代码来完成模型的训练过程。


http://www.ppmy.cn/news/1469663.html

相关文章

企业级开源项目,云缓存解决方案:CacheCloud

CacheCloud:简化缓存管理,释放数据潜力- 精选真开源,释放新价值。 概览 CacheCloud是由搜狐视频团队开发的一款开源的Redis缓存云平台,支持Redis多种架构(Standalone、Sentinel、Cluster)高效管理、有效降低大规模redis运维成本&…

非接触式装配监控技术实现对装配工作站操作的实时动作识别和定位

当今快速发展的工业领域,智能制造作为第四次工业革命的核心,正引领着生产方式的革新。智能制造的关键在于实时监控和数据分析,这不仅能优化生产流程,还能显著提升产品质量和生产效率。其中,装配操作的实时监控对于制造…

html是什么?http是什么?

html Html是什么?http是什么? Html 超文本标记语言;负责网页的架构; http((HyperText Transfer Protocol)超文本传输协议; https(全称:Hypertext Transfer Protocol …

【Android面试八股文】Java中有几种引用关系,它们的区别是什么?

在Java中,引用关系主要分为以下几种: 强引用(Strong Reference)软引用(Soft Reference)弱引用(Weak Reference)虚引用(Phantom Reference) 这些引用类型的区别在于它们对垃圾回收的影响程度。下面是对每种引用类型的详细解释及代码示例: 1. 强引用(Strong Referen…

Linux系统使用Docker安装Dashy导航页结合内网穿透一键发布公网

文章目录 简介1. 安装Dashy2. 安装cpolar3.配置公网访问地址4. 固定域名访问 简介 Dashy 是一个开源的自托管的导航页配置服务,具有易于使用的可视化编辑器、状态检查、小工具和主题等功能。你可以将自己常用的一些网站聚合起来放在一起,形成自己的导航…

VMware ESXi 8.0U2c macOS Unlocker OEM BIOS ConnectX-3 网卡定制版 (集成驱动版)

VMware ESXi 8.0U2c macOS Unlocker & OEM BIOS ConnectX-3 网卡定制版 (集成驱动版) 发布 ESXi 8.0U2 集成驱动版,在个人电脑上运行企业级工作负载 请访问原文链接:https://sysin.org/blog/vmware-esxi-8-u2-sysin/,查看最新版。原创作…

mysql的主从同步

MySQL的主从同步是一种数据复制技术,它允许将一个MySQL数据库服务器上的数据变化自动复制到一个或多个MySQL数据库服务器上。主从同步广泛用于高可用性、负载均衡、读写分离和数据备份。下面详细介绍MySQL主从同步的原理、配置步骤、常见问题及解决方法。 一、基本…

Mac M3 Pro 安装 Zookeeper-3.4.6

1、下载安装包 官方下载地址:https://archive.apache.org/dist/zookeeper/ 网盘下载地址:https://pan.baidu.com/s/1j6iy5bZkrY-GKGItenRB2w?pwdirrx 提取码: irrx 2、解压并添加环境变量 # 将安装包移动到目标目录 mv ~/Download/zookeeper-3.4.6.…