Pytorch中不同的Norm归一化详细讲解

news/2024/9/20 1:22:46/ 标签: pytorch, 人工智能, python, 机器学习源码分析

在这里插入图片描述

在做项目或者看论文时,总是能看到Norm这个关键的Layer,但是不同的Norm Layer具有不同的作用,准备好接招了吗?(本文结论全部根据pytorch官方文档得出,请放心食用)

一. LayerNorm

LayerNorm的公示如下:
y = x − E [ x ] Var ⁡ [ x ] + ϵ ∗ γ + β y=\frac{x-\mathrm{E}[x]}{\sqrt{\operatorname{Var}[x]+\epsilon}} * \gamma+\beta y=Var[x]+ϵ xE[x]γ+β

Parameters(参数):

  • normalized_shape
  • eps(确保分母不为0)
  • elementwise_affine(布尔类型,是否要为每个元素添加一个可学习的仿射变换参数)
  • bias(布尔类型,在elementwise_affine为True时可选择为每个元素另外加上一个可学习的偏置项)

其中可变化的量即可学习的参数即 elementwise_affine涉及的权重以及bias。
归一化的维度是由normalized_shape来决定的。假设输入的张量形状是[B,C,H,W],此时常见的normalized_shape为[C,H,W]。换句话说,由于最后三个维度包含一张完整的图片信息,它会计算每个图片的 CxHxW 张量的均值和标准差,并进行归一化,使得这个张量在归一化后均值为 0,标准差为 1。
举个具体的例子来说明,假设输入张量 x 如下:
x = [ [ 1 2 3 4 5 6 7 8 9 10 11 12 ] [ 13 14 15 16 17 18 19 20 21 22 23 24 ] ] x=\left[\begin{array}{c}{\left[\begin{array}{cccc}1 & 2 & 3 & 4 \\5 & 6 & 7 & 8 \\9 & 10 & 11 & 12\end{array}\right]} \\{\left[\begin{array}{cccc}13 & 14 & 15 & 16 \\17 & 18 & 19 & 20 \\21 & 22 & 23 & 24\end{array}\right]}\end{array}\right] x= 159261037114812 131721141822151923162024
我们假设要对后两个维度进行归一化。
1. 计算均值
E [ x 1 ] = 1 + 2 + 3 + 4 + 5 + 6 + 7 + 8 + 9 + 10 + 11 + 12 12 = 78 12 = 6.5 \mathrm{E}\left[x_{1}\right]=\frac{1+2+3+4+5+6+7+8+9+10+11+12}{12}=\frac{78}{12}=6.5 E[x1]=121+2+3+4+5+6+7+8+9+10+11+12=1278=6.5
E [ x 2 ] = 13 + 14 + 15 + 16 + 17 + 18 + 19 + 20 + 21 + 22 + 23 + 24 12 = 222 12 = 18.5 \mathrm{E}\left[x_{2}\right]=\frac{13+14+15+16+17+18+19+20+21+22+23+24}{12}=\frac{222}{12}=18.5 E[x2]=1213+14+15+16+17+18+19+20+21+22+23+24=12222=18.5
2. 对每个样本的二维切片计算方差。
第一个样本:
Var ⁡ [ x 1 ] = ( 1 − 6.5 ) 2 + ( 2 − 6.5 ) 2 + ⋯ + ( 12 − 6.5 ) 2 12 = 11.9167 \operatorname{Var}\left[x_{1}\right]=\frac{(1-6.5)^{2}+(2-6.5)^{2}+\cdots+(12-6.5)^{2}}{12}=11.9167 Var[x1]=12(16.5)2+(26.5)2++(126.5)2=11.9167
第二个样本:
Var ⁡ [ x 2 ] = ( 13 − 18.5 ) 2 + ( 14 − 18.5 ) 2 + ⋯ + ( 24 − 18.5 ) 2 12 = 11.9167 \operatorname{Var}\left[x_{2}\right]=\frac{(13-18.5)^{2}+(14-18.5)^{2}+\cdots+(24-18.5)^{2}}{12}=11.9167 Var[x2]=12(1318.5)2+(1418.5)2++(2418.5)2=11.9167
3. 计算归一化后的值
方便起见,我们设定 ϵ=0:
y 1 = [ − 1.593 − 1.301 − 1.010 − 0.718 − 0.426 − 0.135 0.157 0.449 0.740 1.032 1.323 1.615 ] y_{1}=\left[\begin{array}{cccc}-1.593 & -1.301 & -1.010 & -0.718 \\-0.426 & -0.135 & 0.157 & 0.449 \\0.740 & 1.032 & 1.323 & 1.615\end{array}\right] y1= 1.5930.4260.7401.3010.1351.0321.0100.1571.3230.7180.4491.615
y 2 = [ − 1.593 − 1.301 − 1.010 − 0.718 − 0.426 − 0.135 0.157 0.449 0.740 1.032 1.323 1.615 ] y_{2}=\left[\begin{array}{cccc}-1.593 & -1.301 & -1.010 & -0.718 \\-0.426 & -0.135 & 0.157 & 0.449 \\0.740 & 1.032 & 1.323 & 1.615\end{array}\right] y2= 1.5930.4260.7401.3010.1351.0321.0100.1571.3230.7180.4491.615
4. 应用可学习的仿射变换(可选)

具体应用:
我们的核心目标是对一个完整对象利用LayerNorm,所以这是我们的第一目标。

  • NLP:
    在NLP领域中,最常见的单体对象就是word。常见的输入形状是[B,Seq_len,Word_dim],即batch_size,每个句子包含几个单词,每个单词的具体维度。所以我们在最后一个维度即单词维度进行归一化。
python"># NLP Example
batch, sentence_length, embedding_dim = 20, 5, 10
embedding = torch.randn(batch, sentence_length, embedding_dim)
layer_norm = nn.LayerNorm(embedding_dim)
# Activate module
layer_norm(embedding)
  • CV:
    视觉也不用多说了,最多的就是在后三个维度(即完整的一张图像上)进行归一化。
python">N, C, H, W = 20, 5, 10, 10
input = torch.randn(N, C, H, W)
# Normalize over the last three dimensions (i.e. the channel and spatial dimensions)
# as shown in the image below
layer_norm = nn.LayerNorm([C, H, W])
output = layer_norm(input)

二. BatchNorm2d

公式与LayerNorm完全相同。但是在代码中其操作的维度不一样。LayerNorm是对整张图像进行归一化(操作后三个维度),而BatchNorm2d则是对通道进行归一化,比如说我们有256张图片作为一个批次,每张图片有3个通道为R,G,B,那么在R通道在归一化需要使用这256张图片的R通道。(G,B通道同理)

Parameters(参数):

  • num_features:定义了通道数。
  • eps :用于数值稳定性。
  • momentum: 控制运行均值和方差的更新速度。
  • affine: 决定是否有可学习的缩放和偏移参数。
  • track_running_stats: 控制是否在推理时使用运行时统计量。

示例代码如下:

python"># With Learnable Parameters
m = nn.BatchNorm2d(100)
# Without Learnable Parameters
m = nn.BatchNorm2d(100, affine=False)
input = torch.randn(20, 100, 35, 45)
output = m(input)

三. InstanceNorm2d

公式还是和上面的完全一样。InstanceNorm2d与BatchNorm2d非常相似,只不过InstanceNorm2d更进一步,它实现了单个样本单通道的归一化。
Parameters(参数):

  • num_features 定义了通道数。
  • eps 用于防止数值不稳定。
  • momentum 控制 running_mean 和 running_var 的更新速度(如果track_running_stats=True)。
  • affine 决定是否有可学习的缩放和偏移参数。
  • track_running_stats 决定是否在推理时使用累计的均值和方差,还是每次使用当前样本的统计量。

示例代码如下:
输入:(B,C,H,W) or (C,H,W)
输出:(B,C,H,W) or (C,H,W) 形状不变,当B=1时即(C,H,W),此时就是支持单个样本进行归一化的情况。

python"># Without Learnable Parameters
m = nn.InstanceNorm2d(100)
# With Learnable Parameters
m = nn.InstanceNorm2d(100, affine=True)
input = torch.randn(20, 100, 35, 45)
output = m(input)

四. GroupNorm

公式还是和上面三个一样。然而GroupNorm在Instance的基础上,可以将通道进行分组归一化。比如说一个样本共有8个通道,设置num_groups=2,那么1-4的channels,2-4的channels将被分组进行归一化。
Parameters(参数):

  • num_groups:定义了将通道分为多少组,每组内独立计算均值和方差。
  • num_channels:定义了输入数据的通道数,确保与 num_groups 匹配。
  • eps:防止除零错误的小值,确保计算稳定性。
  • affine:决定是否为每个通道学习仿射参数(缩放和偏移)。

示例代码:

python">input = torch.randn(20, 6, 10, 10)
# Separate 6 channels into 3 groups
m = nn.GroupNorm(3, 6)
# Separate 6 channels into 6 groups (equivalent with InstanceNorm)
m = nn.GroupNorm(6, 6)
# Put all 6 channels into a single group (equivalent with LayerNorm)
m = nn.GroupNorm(1, 6)
# Activating the module
output = m(input)

纸上得来终觉浅,绝知此事要躬行!多分析源码,收获良多。


http://www.ppmy.cn/news/1519762.html

相关文章

github源码指引:C++嵌入式WEB服务器

初级代码游戏的专栏介绍与文章目录-CSDN博客 我的github:codetoys,所有代码都将会位于ctfc库中。已经放入库中我会指出在库中的位置。 这些代码大部分以Linux为目标但部分代码是纯C的,可以在任何平台上使用。 相关专题: C嵌入式…

Element组件

文章目录 1 Element介绍2 快速入门3 Element组件3.1 Table表格3.1.1 组件演示3.1.2 组件属性详解 3.2 Pagination分页3.2.1 组件演示3.2.2 组件属性详解3.2.3 组件事件详解 3.3 Dialog对话框3.3.1 组件演示3.3.2 组件属性详解 3.4 Form表单3.4.1 组件演示 1 Element介绍 Eleme…

【代码随想录训练营第42期 Day46打卡 - 回文问题 - LeetCode 647. 回文子串 516.最长回文子序列

目录 一、做题心得 二、题目与题解 题目一:647. 回文子串 题目链接 题解1:动态规划 题解2:中心扩展法 题目二:516.最长回文子序列 题目链接 题解1:反转LCS 题解2:动态规划 三、小结 一、做题心得…

react antd点击table行时加选中背景色

在React中使用Ant Design的Table组件时,可以通过rowSelection属性来实现点击行时加亮背景色的功能。 import React from react; import { Table } from antd;const data [{key: 1,name: John Brown,age: 32,address: New York No. 1 Lake Park,},// ... 更多数据…

Leetcode102二叉树的层序遍历(java实现)

今天分享的题目是lee102题,题目的描述如下: 可能做到这道题的小伙伴写过其他关于二叉树的题目,但是一般是使用递归的方式做一个深度遍历,而层序遍历我们该如何做呢? 解题思路:使用一个队列来记录本层节点&a…

docker——compose容器编排!!!

⼀、Docker-compose 定义 1. docker compose 是 docker 官⽅的开源项⽬,负责实现对docker 容器集群的快速编排(容器,依赖,⽹络,挂载。。) 2. compose 是 docker 公司推出的⼀个⼯具软件,可以管理多个docker 容器组成…

手撕Python之序列类型

1.列表---list 索引的使用 当我们有一个数据的时候,我们怎么将这个数据存储到程序呢? 我们定义一个变量,将数据存储在变量中 那么如果有100个数据呢?要定义100个变量吗? 我们是可以用列表这个东西进行多个数据的存…

92. UE5 GAS RPG 使用C++创建GE实现灼烧的负面效果

在正常游戏里,有些伤害技能会携带一些负面效果,比如火焰伤害的技能会携带燃烧效果,敌人在受到伤害后,会接受一个燃烧的效果,燃烧效果会在敌人身上持续一段时间,并且持续受到火焰灼烧。 我们将在这一篇文章里…

[知识分享]华为铁三角工作法

在通信技术领域,尤其是无线通信和物联网领域,“华为铁三角”是华为公司内部的一种销售、交付和服务一体化的运作模式。这种模式强调的是以客户为中心,通过市场、销售、交付和服务三个关键环节的紧密协作,快速响应客户需求&#xf…

upload-labs通关攻略

Pass-1 这里上传php文件说不允许上传 然后咱们开启抓包将png文件改为php文件 放包回去成功上传 Pass-2 进来查看提示说对mime进行检查 抓包把这里改为image/jpg; 放包回去就上传成功了 Pass-3 这里上传php文件它说不允许上传这些后缀的文件 那咱们就可以改它的后缀名来绕过…

HarmonyOS开发实战( Beta5版)AOT编译使用指南

AOT编译使用指南 AOT(Ahead Of Time)即预先编译,在程序运行前,预先编译成高性能机器码,让程序在首次运行就能通过执行高性能机器码获得性能收益 方舟AOT编译器实现了PGO (Profile-Guided-Optimization)编译优化,即通过…

Spring Boot 中 `@Transactional` 注解使用示例

Transactional 注解在 Spring Boot 中用于管理事务。它确保在方法执行过程中,所有数据库操作要么全部成功,要么全部回滚,以维护数据的一致性。下面是一些使用 Transactional 的示例: 1. 在服务层使用 Transactional Service pub…

【React】Redux-toolkit 处理异步操作

安装 npm install reduxjs/toolkit react-redux创建 store src\store\index.js import { configureStore } from reduxjs/toolkit; import homeReducer from ./modules/home;const store configureStore({reducer: {home: homeReducer,}, });export default store;创建 Red…

UE4 BuildCookRun中的Archive的含义

在UE4中,Archive、Cook、Stage、Package、Build的次序是怎么样的? 整体打包过程如下: Build -> Cook-> Stage -> Package -> Archive。其中,Archive 的含义是从Staged目录中拷贝文件到一个额外的目录即Archive目录。被称为“归档…

四十五、【人工智能】【机器学习】- Robust Regression(稳健回归)

系列文章目录 第一章 【机器学习】初识机器学习 第二章 【机器学习】【监督学习】- 逻辑回归算法 (Logistic Regression) 第三章 【机器学习】【监督学习】- 支持向量机 (SVM) 第四章【机器学习】【监督学习】- K-近邻算法 (K-NN) 第五章【机器学习】【监督学习】- 决策树…

【Android】repositories和sourceSets指定了 `libs` 目录的区别

repositories { flatDir { dirs libs } } 这段代码的作用是告诉 Gradle 在指定的目录(这里是 libs 目录)中查找 JAR 文件或 AAR 文件。flatDir 是一种简单的文件目录结构,它不会解析子目录,只会查找指定目录中的文件。 reposito…

Arduino 串口打印小知识点

String str[]{"abc","defg","hijk","lm","n"}; int num; void setup() {Serial.begin(115200);numsizeof(str) /sizeof(str[2]);Serial.print("该数组 str[]的长度:");Serial.print(num); }void loop(…

Python编码系列—Python中的HTTPS与加密技术:构建安全的网络通信

🌟🌟 欢迎来到我的技术小筑,一个专为技术探索者打造的交流空间。在这里,我们不仅分享代码的智慧,还探讨技术的深度与广度。无论您是资深开发者还是技术新手,这里都有一片属于您的天空。让我们在知识的海洋中…

p2p、分布式,区块链笔记:基于IPFS实现的数据库orbitdb笔记

orbitdb orbitdb :Peer-to-Peer Databases for the Decentralized Web 特性说明特点无服务器、分布式、p2p编程语言JavaScript对其他语言的支持A python client for the Orbitdb HTTP API,go-orbit-db, 让我们了解一下谁在使用 js-ipfs&…

jmeter 响应乱码

Jmeter在做接口测试的时候的,如果接口响应的内容中有中文,jmeter的响应内容很可能显示乱码,为了规避这种出现乱码的问题,就要对jmeter的响应结果进行编码处理。 打开jmeter进行接口、压力、性能等测试,出现以下乱码问…