手撕/手写/自己实现 BN层/batch norm/BatchNormalization python torch pytorch

news/2025/2/14 4:02:50/

计算过程

在卷积神经网络中,BN 层输入的特征图维度是 (N,C,H,W), 输出的特征图维度也是 (N,C,H,W)
N 代表 batch size
C 代表 通道数
H 代表 特征图的高
W 代表 特征图的宽

我们需要在通道维度上做 batch normalization,
在一个 batch 中,
使用 所有特征图 相同位置上的 channel 的 所有元素,计算 均值和方差,
然后用计算出来的 均值和 方差,更新对应特征图上的 channel , 生成新的特征图

如下图所示:
对于4个橘色的特征图,计算所有元素的均值和方差,然后在用于更新4个特征图中的元素(原来元素减去均值,除以方差)
![[attachments/BN示意图.png]]

代码

def my_batch_norm_2d_detail(features, eps=1e-5):'''这个函数的写法是为了帮助理解 BatchNormalization 具体运算过程实际使用时这样写会比较慢'''n,c,h,w = features.shapefeatures_copy = features.clone()running_var = torch.randn(c)running_mean = torch.randn(c)for ci in range(c):# 分别 处理每一个通道mean = 0 # 均值var = 0 # 方差_sum = 0 # 对一个 batch 中,特征图相同位置 channel 的每一个元素求和for ni in range(n):            for hi in range(h):for wi in range(w):_sum += features[ni,ci, hi, wi]mean = _sum / (n * h * w) running_mean[ci] = mean_sum = 0# 对一个 batch 中,特征图相同位置 channel 的每一个元素求平方和,用于计算方差 for ni in range(n):            for hi in range(h):for wi in range(w):_sum += (features[ni,ci, hi, wi] - mean) ** 2var = _sum / (n * h * w )running_var[ci] = _sum / (n * h * w - 1)# 更新元素for ni in range(n):            for hi in range(h):for wi in range(w):features_copy[ni,ci, hi, wi] = (features_copy[ni,ci, hi, wi] - mean) / torch.sqrt(var + eps) return features_copy, running_mean, running_varif __name__ == "__main__":torch.set_printoptions(precision=7)torch_bn = nn.BatchNorm2d(4)  # 设置 channel 数torch_bn.momentum = Nonefeatures = torch.randn(4, 4, 2, 2) # (N,C,H,W)torch_bn_output = torch_bn(features)    my_bn_output, running_mean, running_var = my_batch_norm_2d_detail(features)        print(torch.allclose(torch_bn_output, my_bn_output))print(torch.allclose(torch_bn.running_mean, running_mean))print(torch.allclose(torch_bn.running_var, running_var))

注意事项

方差计算

需要注意的是,在训练的过程中,方差有两种不同的计算方式,

在训练时,用于更新特征图的是 有偏方差
而 running_var 的计算,使用的是 无偏方差
在这里插入图片描述

相关链接

官方人员手写BN

"""
Comparison of manual BatchNorm2d layer implementation in Python and
nn.BatchNorm2d@author: ptrblck
"""import torch
import torch.nn as nndef compare_bn(bn1, bn2):err = Falseif not torch.allclose(bn1.running_mean, bn2.running_mean):print('Diff in running_mean: {} vs {}'.format(bn1.running_mean, bn2.running_mean))err = Trueif not torch.allclose(bn1.running_var, bn2.running_var):print('Diff in running_var: {} vs {}'.format(bn1.running_var, bn2.running_var))err = Trueif bn1.affine and bn2.affine:if not torch.allclose(bn1.weight, bn2.weight):print('Diff in weight: {} vs {}'.format(bn1.weight, bn2.weight))err = Trueif not torch.allclose(bn1.bias, bn2.bias):print('Diff in bias: {} vs {}'.format(bn1.bias, bn2.bias))err = Trueif not err:print('All parameters are equal!')class MyBatchNorm2d(nn.BatchNorm2d):def __init__(self, num_features, eps=1e-5, momentum=0.1,affine=True, track_running_stats=True):super(MyBatchNorm2d, self).__init__(num_features, eps, momentum, affine, track_running_stats)def forward(self, input):self._check_input_dim(input)exponential_average_factor = 0.0if self.training and self.track_running_stats:if self.num_batches_tracked is not None:self.num_batches_tracked += 1if self.momentum is None:  # use cumulative moving averageexponential_average_factor = 1.0 / float(self.num_batches_tracked)else:  # use exponential moving averageexponential_average_factor = self.momentum# calculate running estimatesif self.training:mean = input.mean([0, 2, 3])# use biased var in trainvar = input.var([0, 2, 3], unbiased=False)n = input.numel() / input.size(1)with torch.no_grad():self.running_mean = exponential_average_factor * mean\+ (1 - exponential_average_factor) * self.running_mean# update running_var with unbiased varself.running_var = exponential_average_factor * var * n / (n - 1)\+ (1 - exponential_average_factor) * self.running_varelse:mean = self.running_meanvar = self.running_varinput = (input - mean[None, :, None, None]) / (torch.sqrt(var[None, :, None, None] + self.eps))if self.affine:input = input * self.weight[None, :, None, None] + self.bias[None, :, None, None]return input# Init BatchNorm layers
my_bn = MyBatchNorm2d(3, affine=True)
bn = nn.BatchNorm2d(3, affine=True)compare_bn(my_bn, bn)  # weight and bias should be different
# Load weight and bias
my_bn.load_state_dict(bn.state_dict())
compare_bn(my_bn, bn)# Run train
for _ in range(10):scale = torch.randint(1, 10, (1,)).float()bias = torch.randint(-10, 10, (1,)).float()x = torch.randn(10, 3, 100, 100) * scale + biasout1 = my_bn(x)out2 = bn(x)compare_bn(my_bn, bn)torch.allclose(out1, out2)print('Max diff: ', (out1 - out2).abs().max())# Run eval
my_bn.eval()
bn.eval()
for _ in range(10):scale = torch.randint(1, 10, (1,)).float()bias = torch.randint(-10, 10, (1,)).float()x = torch.randn(10, 3, 100, 100) * scale + biasout1 = my_bn(x)out2 = bn(x)compare_bn(my_bn, bn)torch.allclose(out1, out2)print('Max diff: ', (out1 - out2).abs().max())

http://www.ppmy.cn/news/41444.html

相关文章

Text to image论文精读GigaGAN: 生成对抗网络仍然是文本生成图像的可行选择

GigaGAN是Adobe和卡内基梅隆大学学者们提出的一种新的GAN架构,作者设计了一种新的GAN架构,推理速度、合成高分辨率、扩展性都极其有优势,其证明GAN仍然是文本生成图像的可行选择之一。 文章链接:https://arxiv.org/abs/2303.0551…

【AIGC未来的发展方向】面向人工智能的第一步,一文告诉你人工智能是什么以及未来的方向分析

人工智能的概念 当人们提到“人工智能(AI)”时,很多人会想到机器人和未来世界的科幻场景,但AI的应用远远不止于此。现在,AI已经广泛应用于各种行业和生活领域,为我们带来了无限可能。 AI是一个广泛的概念…

《程序员面试金典(第6版)》面试题 08.14. 布尔运算(动态规划,分治,递归,难度hard++)

题目描述 给定一个布尔表达式和一个期望的布尔结果 result,布尔表达式由 0 (false)、1 (true)、& (AND)、 | (OR) 和 ^ (XOR) 符号组成。实现一个函数,算出有几种可使该表达式得出 result 值的括号方法。 示例 1: 输入: s “1^0|0|1”, result 0 …

干货分享 - MatLab || 与LaTeX的混合使用指南

目录 1、前言 2、Latex基础 3、Latex尝鲜 4、Latex在MatLab中换行 5、Latex在MatLab中小花招 6、附录1:Tex对照表 7、附录2:常用Tex字符 1、前言 LaTeX语言作为应用最广泛的Tex格式,Tex这种语言具有简单排版和程序设计的功能。 利用…

I/O软件的层次结构及其概念

I/O软件在计算机的什么部位 I/O (Input/Output) 软件是运行在计算机操作系统中的一种软件,用于管理计算机与外部设备之间的输入和输出数据交换。这种软件通常与计算机的驱动程序和硬件设备交互,以实现输入和输出数据的传输和处理。 因为 I/O 软件是操作…

同程旅行面试_4/14

今天准备了一上午-下午4:30 进行了30min面试 问八股居多,主要考题如下: 1. 说一下 和 的区别(先进行类型转换,再进行值转换?) 2.说一下flex布局 3.一栏固定 另一栏随着宽度改变而改变 4.说一下移动端适…

12.Union 结构

文章目录十二、Union 结构十二、Union 结构 有时需要一种数据结构,不同的场合表示不同的数据类型。比如,如果只用一种数据结构表示水果的“量”,这种结构就需要有时是整数(6个苹果),有时是浮点数&#xff…

CoreDNS 性能优化

CoreDNS 作为 Kubernetes 集群的域名解析组件,如果性能不够可能会影响业务,本文介绍几种 CoreDNS 的性能优化手段。合理控制 CoreDNS 副本数考虑以下几种方式:根据集群规模预估 coredns 需要的副本数,直接调整 coredns deployment 的副本数:k…