秃姐学AI系列之:批量归一化 + 代码实现

目录

批量归一化

核心想法

批归一化在做什么

总结

代码实现

从零实现

创建一个正确的BatchNorm层

应用BatchNorm于LeNet模型

简单实现

QA


批量归一化

训练深层神经网络是十分困难的,特别是在较短的时间内使他们收敛更加棘手。

因为数据在网络最开始,而损失在结尾。训练的过程是一个前向传播的过程,而参数更新是一个从后往前的更新方式。会导致越靠近损失的参数,梯度更新越大(因为是一些很小的值不断的乘,会变得越来越小),而最终导致后面的层训练的比较快

虽然底部层训练的慢,但是底部层一变化,所有的都得跟着变。导致最后的那些层需要重新学习多次!从而导致收敛变慢。

批量归一化(batch normalization),这是一种流行且有效的技术,可持续加速深层网络的收敛速度。 再结合 残差块,批量归一化 使得研究人员能够训练100层以上的网络。

虽然这个思想不新了,但是这个层确实是近几年出来的,大概在16年左右。当你要做很深的神经网络之后,会发现加入批量归一化,效果很好。基本成为现在不可避免的一个层了。

核心想法

当我们训练时,中间层中的变量(例如,多层感知机中的仿射变换输出)可能具有更广的变化范围:不论是沿着从输入到输出的层,跨同一层中的单元,或是随着时间的推移,模型参数的随着训练更新变幻莫测。

所以批量归一化的思想就是,我固定住分布,不管哪一层的 输出 还是 梯度,都符合某一个分布。使得网络没有特别大的转变,那么在学习细微的数值的时候就比较容易。当然具体什么分布,分布细微的东西可以再调整。

  • 固定小批量里面的 均值 方差

  • 然后在做额外的调整(可学习的参数)

式子中的 \mu _{B} 和 \sigma _{B} 是根据数据学出来的,而 \gamma 和 \beta 是一个可学习的参数

这两个参数的意义是  假设直接把数据设为均值为0,方差为1 不是那么适合,那就可以去需欸一个新的均值和方差去更加适应网络

但是会限制住 \gamma 和 \beta 不要变化的过于猛烈

  • 可学习的参数为 \gamma 和 \beta
  • 作用在
    • 全连接层和卷积层输出上,激活函数前
    • 全连接层和卷积层输入上,对输入做一个均值变化,使得输入的 方差、均值 比较好

为什么要放在激活函数之前:ReLU把你所有东西都变成正数,如果放在ReLU之后,批归一化层又给你算的奇奇怪怪的

可以认为批归一化是个线性变换

  • 对全连接层,作用在特征维度
  • 对于卷积层,作用在通道维度

批归一化在做什么

  • 最初论文是想用它来减少内部协变量转移
  • 后续有论文指出它可能就是通过在每个小批量里加入噪音来控制模型复杂度

认为 \hat{\mu _{B}} 是随机偏移(当前样本计算而来),\hat{\sigma _{B}} 是随机缩放(当前样本计算而来)

  • 因此没必要和丢弃法混合使用 

按照上面的思路的话,本来批归一化就是一个控制模型复杂度的方法,丢弃法也是。在 批归一化 上再加 丢弃,可能就没那么有用了。 

总结

  • 批量归一化:固定小批量中的均值和方差,然后学习出适合的偏移和缩放
  • 可以加速收敛速度,但一般不改变模型的精度
  • 在模型训练过程中,批量归一化不断调整神经网络的中间输出,使整个神经网络各层的中间输出值更加稳定。
  • 批量归一化在全连接层和卷积层的使用略有不同。
  • 批量归一化层 和 丢弃法 一样,在训练模式和预测模式下计算不同。
  • 批量归一化 有许多有益的副作用,主要是正则化。另一方面,”减少内部协变量偏移“的原始动机似乎不是一个有效的解释。

代码实现

从零实现

详细注释版

import torch
from torch import nn
from d2l import torch as d2l# 参数(X, 学习的参数:gamma、beta,预测用的全局的均值和方差:moving_mean、moving_var,极小值:eps,用来更新全局均值和方差的参数:momentum,通常取0.9 or 固定数字)
def batch_norm(X, gamma, beta, moving_mean, moving_var, eps, momentum):# 通过is_grad_enabled来判断当前模式是训练模式还是预测模式if not torch.is_grad_enabled():# 如果是在预测模式下,直接使用传入的移动平均所得的均值和方差,因为预测的时候可能没有批量,只有一张图片 or 一个样本X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)else:assert len(X.shape) in (2, 4)if len(X.shape) == 2:   #X.shape = 2:全连接层# 使用全连接层的情况,计算特征维上的均值和方差mean = X.mean(dim=0)  #(1,n)的行向量,按行求均值 = 计算每一列的均值var = ((X - mean) ** 2).mean(dim=0)   # 依旧是按行,所以我们的方差也是行向量else:    # X.shape = 4:卷积层# 使用二维卷积层的情况,计算通道维上(axis=1)的均值和方差。# 这里我们需要保持X的形状以便后面可以做广播运算mean = X.mean(dim=(0, 2, 3), keepdim=True)  #(1,n,1,1)的形状var = ((X - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)  #(1,n,1,1)的形状# 训练模式下,用当前的均值和方差做标准化X_hat = (X - mean) / torch.sqrt(var + eps)# 更新移动平均的均值和方差,最终会无限逼近真实的数据集上的全集均值、方差moving_mean = momentum * moving_mean + (1.0 - momentum) * meanmoving_var = momentum * moving_var + (1.0 - momentum) * varY = gamma * X_hat + beta  # 缩放和移位return Y, moving_mean.data, moving_var.data

创建一个正确的BatchNorm层

我们现在可以创建一个正确的 BatchNorm 层。 这个层将保持适当的参数:拉伸 gamma 和偏移 beta,这两个参数将在训练过程中更新。 此外,我们的层将保存均值和方差的移动平均值,以便在模型预测期间随后使用。

撇开算法细节,注意我们实现层的基础设计模式。

  • 通常情况下,我们用一个单独的函数定义其数学原理,比如说 batch_norm。
  • 然后,我们将此功能集成到一个自定义层中,其代码主要处理数据移动到训练设备(如GPU)、分配和初始化任何必需的变量、跟踪移动平均线(此处为均值和方差)等问题。

为了方便起见,我们并不担心在这里自动推断输入形状,因此我们需要指定整个特征的数量。 不用担心,深度学习框架中的 批归一化 API 将为我们解决上述问题。

class BatchNorm(nn.Module):# num_features:完全连接层的输出数量或卷积层的输出通道数。# num_dims:2表示完全连接层,4表示卷积层def __init__(self, num_features, num_dims):super().__init__()if num_dims == 2:shape = (1, num_features)else:shape = (1, num_features, 1, 1)# 参与求梯度和迭代的拉伸和偏移参数,分别初始化成1和0,需要被迭代self.gamma = nn.Parameter(torch.ones(shape))self.beta = nn.Parameter(torch.zeros(shape))# 非模型参数的变量初始化为0和1,不需要迭代self.moving_mean = torch.zeros(shape)self.moving_var = torch.ones(shape)def forward(self, X):# 如果X不在内存上,将moving_mean和moving_var# 复制到X所在显存上if self.moving_mean.device != X.device:self.moving_mean = self.moving_mean.to(X.device)self.moving_var = self.moving_var.to(X.device)# 保存更新过的moving_mean和moving_varY, self.moving_mean, self.moving_var = batch_norm(X, self.gamma, self.beta, self.moving_mean,self.moving_var, eps=1e-5, momentum=0.9)return Y

应用BatchNorm于LeNet模型

回想一下,批量规范化是在卷积层或全连接层之后、相应的激活函数之前应用的。

net = nn.Sequential(nn.Conv2d(1, 6, kernel_size=5), BatchNorm(6, num_dims=4), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2),nn.Conv2d(6, 16, kernel_size=5), BatchNorm(16, num_dims=4), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2), nn.Flatten(),nn.Linear(16*4*4, 120), BatchNorm(120, num_dims=2), nn.Sigmoid(),nn.Linear(120, 84), BatchNorm(84, num_dims=2), nn.Sigmoid(),nn.Linear(84, 10))  # 没有必要对输出计算归一化

简单实现

除了使用我们刚刚定义的BatchNorm,我们也可以直接使用深度学习框架中定义的BatchNorm。 该代码看起来几乎与我们上面的代码相同。

net = nn.Sequential(nn.Conv2d(1, 6, kernel_size=5), nn.BatchNorm2d(6), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2),nn.Conv2d(6, 16, kernel_size=5), nn.BatchNorm2d(16), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2), nn.Flatten(),nn.Linear(256, 120), nn.BatchNorm1d(120), nn.Sigmoid(),nn.Linear(120, 84), nn.BatchNorm1d(84), nn.Sigmoid(),nn.Linear(84, 10))

QA

  • Xavier 和 batch normalization 以及其他正则化手段有什么区别

Xavier 是选取比较好的初始化方法,使得网络在开始的时候比较稳定,但不能保证之后

BN 保证在整个模型训练的时候都强行的在每一层后面做归一化(其实不应该叫normalization,学深度学习的数学没学好,应该是归一化,不是正则化)

  • BN是不是一般用于深层网络,浅层MLP加上BN效果好像不好

BN对深度网络效果更好,对于浅层网络没有太多太多用处,因为只有网络深度起来了才会出现我们上面提到的后面的层更快的训练好,从而被反复作废、训练、作废、训练的情况

  •  BN是做了线性变换,和加一个线性层有什么区别?

没啥太大的区别,只能说如果加了一个线性层,线性层可能不一定能学到 BN 学到的那些东西。只是一个线性层,做一个线性变换,没办法给数值做变化(均值为1,方差为0)

  • layerNorm 和 batchNorm的区别

一般来说,layerNorm 用于比较大的网络,作用在图上,batchNorm就为1,做不了batchNorm


http://www.ppmy.cn/news/1516293.html

相关文章

【GH】【EXCEL】P4: Chart

文章目录 data and chartdonut chart (radial chart)Radial Chart bar chartBar Chart line chartLine Chart Scatter ChartScatter Chart Surface ChartSurface Chart Chart DecoratorsChart Decorators Chart GraphicsChart Graphics data and chart donut chart (radial cha…

[数据集][目标检测]流水线物件检测数据集VOC+YOLO格式9255张26类别

数据集格式:Pascal VOC格式YOLO格式(不包含分割路径的txt文件,仅仅包含jpg图片以及对应的VOC格式xml文件和yolo格式txt文件) 图片数量(jpg文件个数):9255 标注数量(xml文件个数):9255 标注数量(txt文件个数):9255 标注…

【前端基础篇】CSS基础速通万字介绍(下篇)

文章目录 前言背景属性背景颜色背景图片背景平铺背景位置背景尺寸 圆角矩形生成圆形生成圆角矩形 Chrome调试工具打开方式标签页含义elements标签页使用 元素显示模式块级元素行内元素/内联元素行内元素和块级元素的区别 盒模型边框内边距外边距 块级元素水平居中去除浏览器默认…

PPP简介

介绍PPP特性的定义和目的。 定义 PPP(Point-to-Point Protocol)协议是一种点到点链路层协议,主要用于在全双工的同异步链路上进行点到点的数据传输。 目的 PPP协议是在串行线IP协议SLIP(Serial Line Internet Protocol&#x…

【Linux —— 线程同步 - 条件变量】

Linux —— 线程同步 - 条件变量 条件变量的概念互斥量与条件变量的关系条件变量的操作代码示例 条件变量的概念 条件变量是一种用于线程间同步的机制,主要用于协调线程之间的执行顺序,允许线程在某个条件不满足时进入等待状态,直到其他线程通…

Docker的概述及如何启动docker的镜像、远程管理宿主机的docker进程

一、概述: 1、Docker 是什么? Docker 是⼀个开源的应⽤容器引擎,可以实现虚拟化,完全采用“沙盒”机制,容器之间不会存在任何接口。 2、Docker 和虚拟机的区别: 1)启动速度:Dock…

【Material-UI】RadioGroup组件:单选按钮组详解

文章目录 一、RadioGroup 组件概述1. 组件介绍2. 基本用法 二、RadioGroup 的关键特性1. 布局方向2. 受控组件3. 表单集成 三、RadioGroup 的实际应用场景1. 用户偏好选择2. 付款方式选择 四、总结 Material-UI 是一个广泛使用的 React UI 框架,提供了丰富的组件库以…

Linux系统性能调优技巧

Linux系统性能调优是一个复杂而细致的过程,它涉及到硬件、软件、配置、监控和调优策略等多个方面。以下将详细阐述Linux系统性能调优的技巧: 一、硬件优化 CPU优化 选择适合的CPU:根据应用需求选择多核、高频的CPU,以满足高并发…

OpenGuass under Ubuntu_22.04 install tutorial

今天开始短学期课程:数据库课程设计。今天9点左右在SL1108开课,听陈老师讲授了本次短学期课程的要求以及任务安排,随后讲解了国产数据库的三层架构的逻辑。配置了大半天才弄好,放一张成功的图片,下面开始记录成功的步骤…

【uniapp】图片合成并导入base64

两张图片合成,宽度固定,高度根据图片自适应 调用 this.mergeImgs(this.imgList).then((res)>{console.log(res,图片base64) })方法 mergeImgs(imgList) {// 图片合成return new Promise((resolve, reject) > {Promise.all(this.fileDtoList.map(im…

在银河麒麟服务器V10上源码编译安装mysql-5.7.42-linux-glibc2.12-x86_64

在银河麒麟服务器V10上源码编译安装mysql-5.7.42-linux-glibc2.12-x86_64 一、卸载MariaDB(如果已安装)二、下载MySQL源码包并解压三、安装编译所需的工具和库四、创建MySQL的安装目录及数据库存放目录五、编译安装MySQL六、配置MySQL七、设置环境变量八…

使用canal增量同步ES索引库数据

Canal增量数据同步利器 Canal介绍 canal主要用途是基于 MySQL 数据库增量日志解析,并能提供增量数据订阅和消费,应用场景十分丰富。 github地址:https://github.com/alibaba/canal 版本下载地址:https://github.com/alibaba/c…

8月15日

上午开会 rag继续 异构大模型 狂野飙车9之前的账号终于找回来了 下午 关于minicpm的代码 minicpm-v 大模型预训练论文&方法总结 - 知乎 (zhihu.com) 这里有讲解的代码 发现还是先推荐把llava的掌握好了之后再看minicpm 多模态大模型LLaVA模型讲解——transformers源…

ARM——驱动——内核编译

一、内核的介绍 Linux内核是Linux操作系统的核心内容,它负责管理系统的硬件资源,并为上层的应用程序提供接口。(在上文都有所介绍) 功能: 进程管理:内核负责创建、调度、同步和终止进程。它还管理进程间的…

递归和迭代

递归可以用迭代来解决,但迭代不一定能用递归来实现。 递归可以用栈来实现,保存函数的参数和返回值。。eg:深度优先搜索、斐波那契数列迭代就是循环(如for、while) 递归转迭代 递归本质上是通过函数调用自身来解决问…

汽车冷却液温度传感器

1、冷却液温度传感器的功能 发动机冷却液温度传感器,也称为ECT,是帮助保护发动机,提高发动机工作效率以及帮助发动机稳定运行的非常重要的传感器之一。 发动机冷却液温度 (ECT) 传感器用于测量发动机的冷却液温度&…

基于UDS的Flash 刷写——BootLoad刷写流程详解

从0开始学习CANoe使用 从0开始学习车载测试 相信时间的力量 星光不负赶路者,时光不负有心人。 目录 流程概述UDS流程详解释前编程①诊断会话控制 - 切换到扩展会话(10 03)②例程控制-预编程条件检查(31 01 02 03)③DTC…

QT中使用QAxObject类读取xlsx文件内容并显示在ui界面

一、源码 #ifndef MAINWINDOW_H #define MAINWINDOW_H#include <QMainWindow>QT_BEGIN_NAMESPACE namespace Ui { class MainWindow; } QT_END_NAMESPACEclass MainWindow : public QMainWindow {Q_OBJECTpublic:MainWindow(QWidget *parent nullptr);~MainWindow();pr…

下载B站视频作为PPT素材

下载B站视频作为PPT素材 1. 下载原理2. 网页分析3. 请求页面&#xff0c;找到数据4. 数据解析5. 音频、视频下载6. 合并音频与视频7. 完整代码 其实使用爬虫也不是第一次了&#xff0c;之前从网站爬过图片&#xff0c;下载过大型文件&#xff0c;如今从下载视频开始才想到要写一…

【C++算法/学习】位运算详解

✨ 忍能对面不相识&#xff0c;仰面欲语泪现流 &#x1f30f; &#x1f4c3;个人主页&#xff1a;island1314 &#x1f525;个人专栏&#xff1a;算法学习 &#x1f680; 欢迎关注&#xff1a;&#x1f44d;点赞 &…