RESNET

ops/2024/10/11 9:23:06/

ResNet

文章目录

  • ResNet
    • 主要内容
      • 开发背景
      • 解决两个问题:
        • 1. 梯度消失和梯度爆炸
        • 2. 退化问题:
      • 解决方法
        • 1. BN(Batch Normalization)层
        • 2. 残差块
    • Pytorch实现
      • BasicBlock
      • BottleNeck
      • ResNet

主要内容

开发背景

残差神经网络(ResNet)是由微软研究院的何恺明、张祥雨、任少卿、孙剑等人提出的, 斩获2015年ImageNet竞赛中分类任务第一名, 目标检测第一名。 残差神经网络的主要贡献是发现了“退化现象(Degradation)”,并针对退化现象发明了 “直连边/短连接(Shortcut connection)”,极大的消除了深度过大的神经网络训练困难问题。神经网络的“深度”首次突破了100层、最大的神经网络甚至超过了1000层。

论文地址:Deep Residual Learning for Image Recognition

解决两个问题:

1. 梯度消失和梯度爆炸

梯度消失:若每一层的误差梯度小于1,反向传播时,网络越深,梯度越趋近于0
梯度爆炸:若每一层的误差梯度大于1,反向传播时,网络越深,梯度越来越大

2. 退化问题:

随着层数的增加预测效果反而越来越差。

在这里插入图片描述

随着网络层数增加,出现了新的问题:退化问题,在训练集上准确率甚至下降了。这个不能解释为过拟合,因为过拟合表现为在训练集上表现更好才对。退化问题说明了深度网络不能很简单地被很好地优化。作者通过实验说明:通过浅层网络y=x 等同映射构造深层模型,结果深层模型并没有比浅层网络有更低甚至等同的错误率,推断退化问题可能是因为深层的网络很那难通过训练利用多层网络拟合同等函数。

解决方法

1. BN(Batch Normalization)层

为了解决梯度消失或梯度爆炸问题,ResNet论文提出通过数据的预处理以及在网络中使用 BN(Batch Normalization)层来解决。

2. 残差块

ResNet团队分别构建了带有“直连边(Shortcut Connection)”的ResNet残差块、以及降采样的ResNet残差块,区别是降采样残差块的直连边增加了一个1×1的卷积操作。对于直连边,当输入和输出维度一致时,可以直接将输入加到输出上,这相当于简单执行了同等映射,不会产生额外的参数,也不会增加计算复杂度。但是当维度不一致时,这就不能直接相加,通过添加1×1卷积调整通道数。这种残差学习结构可以通过前向神经网络+直连边实现, 而且整个网络依旧可以通过端到端的反向传播训练。结构如下图所示:

在这里插入图片描述

从数学角度解释:

在这里插入图片描述

深度残差网络。如果深层网络的后面那些层是恒等映射,那么模型就退化为一个浅层网络。所以要解决的就是学习恒等映射函数。但是直接让一些层去拟合一个潜在的恒等映射函数H(x) = x,比较困难,这可能就是深层网络难以训练的原因。但是,如果把网络设计为H(x) = F(x) + x。我们可以转换为学习一个残差函数F(x) = H(x) - x. 只要F(x)=0,就构成了一个恒等映射H(x) = x. 此外,拟合残差会更加容易。

总的来说,一是其导数总比原导数加1,这样即使原导数很小时,也能传递下去,能解决梯度消失的问题; 二是y=f(x)+x式子中引入了恒等映射(当f(x)=0时,y=x),解决了深度增加时神经网络的退化问题。

Pytorch实现

ResNet实现cifar100分类的代码放在GitHub: pytorch-cifar100了。该部分代码在项目中models/resnet.py里面。

参考github项目pytorch-cifar100

在ResNet中最重要的就是残差结构,一共提供了两种:

  • BasicBlock:两层3 * 3卷积用于实现18-layer和34-layer
  • BottleNeck:用于实现更深层的网络,如50-layer, 101-layer, 152-layer

BasicBlock

在这里插入图片描述

import torch
import torch.nn as nnclass BasicBlock(nn.Module):expansion = 1def __init__(self, in_channels, out_channels, stride=1, *args, **kwargs) -> None:super().__init__(*args, **kwargs)#residual functionself.residual_function = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=3,stride=stride,padding=1,bias=False),nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True),nn.Conv2d(out_channels, out_channels*BasicBlock.expansion, kernel_size=3, padding=1, bias=False),nn.BatchNorm2d(out_channels*BasicBlock.expansion))self.shortcut = nn.Sequential()# 判断输出输入维度是否一致,不一致则使用1 * 1卷积进行升维或降维。 if stride != 1 or in_channels != BasicBlock.expansion * out_channels:self.shortcut = nn.Sequential(nn.Conv2d(in_channels, out_channels*BasicBlock.expansion, kernel_size=1, stride=stride, bias=False),nn.BatchNorm2d(out_channels*BasicBlock.expansion))def forward(self,x):return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x))

BottleNeck

在这里插入图片描述

class BottleNeck(nn.Module):expansion = 4 def __init__(self, in_channels, out_channels, stride=1, *args, **kwargs) -> None:super().__init__(*args, **kwargs)self.residual_function = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=1,bias=False),nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True),nn.Conv2d(out_channels, out_channels, stride=stride, kernel_size=3, padding=1, bias=False),nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True),nn.Conv2d(out_channels, out_channels*BottleNeck.expansion, kernel_size=1, bias=False),nn.BatchNorm2d(out_channels*BottleNeck.expansion))self.shortcut = nn.Sequential()if stride != 1 or in_channels != BottleNeck.expansion * out_channels:self.shortcut = nn.Sequential(nn.Conv2d(in_channels, out_channels*BottleNeck.expansion, kernel_size=1, stride=stride, bias=False),nn.BatchNorm2d(out_channels*BottleNeck.expansion))def forward(self,x):return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x))

ResNet

最后我们按照下图的网络结构来构建ResNet

在这里插入图片描述

class ResNet(nn.Module):def __init__(self, block, num_block, num_classes = 100, *args, **kwargs) -> None:super().__init__(*args, **kwargs)self.inchannels = 64self.conv1 = nn.Sequential(nn.Conv2d(3,64, kernel_size=3, padding=1, bias=False),nn.BatchNorm2d(64),nn.ReLU(inplace=True))self.conv2_x = self._maker_layer(block, 64, num_block[0], 1)self.conv3_x = self._maker_layer(block, 128, num_block[1], 2)self.conv4_x = self._maker_layer(block, 256, num_block[2], 2)self.conv5_x = self._maker_layer(block, 512, num_block[3], 2)self.avg_pool = nn.AdaptiveAvgPool2d((1,1))self.fc = nn.Linear(512 * block.expansion, num_classes)def _maker_layer(self,block, out_channels, num_blocks, stride):strides = [stride] + [1] * (num_blocks -1)layers = []for stride in strides:layers.append(block(self.inchannels, out_channels, stride))self.inchannels = out_channels * block.expansionreturn nn.Sequential(*layers)def resnet18():return ResNet(BasicBlock, [2, 2, 2, 2])def resnet34():return ResNet(BasicBlock, [3, 4, 6, 3])def resnet50():return ResNet(BottleNeck, [3, 4, 6, 3])def resnet101():return ResNet(BottleNeck, [3, 4, 23, 3])def resnet152():return ResNet(BottleNeck, [3, 8, 36, 3]) 

http://www.ppmy.cn/ops/86952.html

相关文章

我在高职教STM32——串口通信(5)

大家好,我是老耿,高职青椒一枚,一直从事单片机、嵌入式、物联网等课程的教学。对于高职的学生层次,同行应该都懂的,老师在课堂上教学几乎是没什么成就感的。正因如此,才有了借助 CSDN 平台寻求认同感和成就感的想法。在这里,我准备陆续把自己花了很多心思的教学设计分享…

了解ChatGPT API

要了解如何使用 ChatGPT API,可以参考几个有用的资源和教程,这些资源能帮助你快速开始使用 API 进行项目开发。下面是一些推荐的资源: OpenAI 官方文档: 访问 OpenAI 的官方网站可以找到 ChatGPT API 的详细文档。这里包括了 API …

Java高手之路:每日一练,技能精进秘籍

目录 一、题目知识点java中有两种方式实现线程Servlet生命周期总结 一、题目 选自牛客网 1.后端获取数据,向前端输出过程中,以下描述正确的是 A.对于前端过滤过的参数,属于可信数据,可以直接输出到前端页面 B.对于从数据库获得的…

iOS object-C 解答算法:找到所有数组中消失的数字(leetCode-448)

找到所有数组中消失的数字(leetCode-448) 题目如下图:(也可以到leetCode上看完整题目,题号448) 光看题看可能有点难以理解,我们结合示例1来理解一下这道题. 有8个整数的数组 nums [4,3,2,7,8,2,3,1], 求在闭区间[1,8]范围内(即1,2,3,4,5,6,7,8)的数字,哪几个没有出现在数组 …

智云-一个抓取web流量的轻量级蜜罐

智云-一个抓取web流量的轻量级蜜罐 安装环境要求 apache php7.4 mysql8 github地址 https://github.com/xiaoxiaoranxxx/POT-ZHIYUN 系统演示

DiAD代码use_checkpoint

目录 1、梯度检查点理解2、 torch.utils.checkpoint.checkpoint函数 1、梯度检查点理解 梯度检查点(Gradient Checkpointing)是一种深度学习优化技术,它的目的是减少在神经网络训练过程中的内存占用。在训练深度学习模型时,我们需…

JAVA零基础学习3(Scanner类,字符串,StringBuilder,StringJoinder,ArrayList成员方法)

JAVA零基础学习3 Scanner类输入示例代码代码解释完整代码1. 读取字符串2. 读取整数3. 读取浮点数4. 读取布尔值5. 读取单个单词6. 读取长整型数7. 读取短整型数8. 读取字节数注意事项总结 API 字符串解释示例解释解决方法示例:使用 StringBuilder String…

计算机基础(day1)

1.什么是内存泄漏?什么是内存溢出?二者有什么区别? 2.了解的操作系统有哪些? Windows,Unix,Linux,Mac 3. 什么是局域网,广域网? 4.10M 兆宽带是什么意思?理论…