Batch Size 不同对evaluation performance的影响

devtools/2024/10/21 5:53:11/

目录

    • 问题描述
    • 如果是bug
    • batch size的设置问题
    • 尝试使用GroupNorm解决batchsize不同带来的问题
    • 参考文章

问题描述

深度学习网络训练时,使用较小的batch size训练网络后,如果换用较大的batch size进行evaluation,网络的预测能力会显著下降。如果evaluation的batch size和train的batch size大小相同时,则不会遇到此类问题。

PyTorch Forums – Performance highly degraded when eval() is activated in the test phase

如果是bug

  1. metric会根据batch_size的大小变化(但并不显著),metric按每个batch分别进行计算
  2. 缺失model.eval()指令:with torch.no_grad() 对dropout和batch normalization不起固定作用。
    1. nn.Dropout层参数不会固定
    2. nn.BatchNorm2d()
      1. PyTorch – BatchNorm2d BatchNorm2d函数中的参数track_running_stats:trainningtrack_running_statstrack_running_stats=True表示跟踪整个训练过程中的batch的统计特性,得到方差和均值,而不只是仅仅依赖与当前输入的batch的统计特性。相反的,如果track_running_stats=False那么就只是计算当前输入的batch的统计特性中的均值和方差了。当在推理阶段的时候,如果track_running_stats=False,此时如果batch_size比较小,那么其统计特性就会和全局统计特性有着较大偏差,可能导致糟糕的效果。
      2. trainning=False, track_running_stats=True。这个是期望中的测试阶段的设置,此时BN会用之前训练好的模型中的(假设已经保存下了)running_meanrunning_var并且不会对其进行更新。一般来说,只需要设置model.eval()其中model中含有BN层,即可实现这个功能。
  3. Dataloader中加入了随机处理,例如RandomCrop
  4. 没有固定随机种子

batch size的设置问题

如果batch size较小,会导致上述running_mean和running_var不准确。参考文章Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift,当模型训练完成后,

x ^ = x − E [ x ] V a r [ x ] + ϵ \hat{x} = \frac{x-E[x]}{\sqrt{Var[x]+\epsilon}} x^=Var[x]+ϵ xE[x]

其中, V a r [ x ] = m m − 1 E B [ σ B 2 ] Var[x]=\frac{m}{m-1}E_B[\sigma_B^2] Var[x]=m1mEB[σB2],the expectation is over training mini-batches of size m and σ B 2 \sigma_B^2 σB2 are their sample variances.

尝试使用GroupNorm解决batchsize不同带来的问题

归一化的分类

<a class=归一化的分类" />
LN 和 IN 在视觉识别上的成功率都是很有限的,对于训练序列模型(RNN/LSTM)或生成模型(GAN)很有效。

所以,在视觉领域,BN用的比较多,GN就是为了改善BN的不足而来的。

GN 把通道分为组,并计算每一组之内的均值和方差,以进行归一化。GN 的计算与批量大小无关,其精度也在各种批量大小下保持稳定。可以看到,GN和LN很像。

参考文章

pytorch 每次测试结果不同
Batch Normalization
深度学习中的组归一化(GroupNorm)


http://www.ppmy.cn/devtools/56983.html

相关文章

计算机毕业设计PyFlink+Spark+Hive民宿推荐系统 酒店推荐系统 民宿酒店数据分析可视化大屏 民宿爬虫 民宿大数据 知识图谱 机器学习

本科毕业设计(论文) 开题报告 学院 &#xff1a; 计算机学院 课题名称 &#xff1a; 民宿数据可视化分析系统的设计与实现 姓名 &#xff1a; 庄贵远 学号 &#xff1a; 2020135232 专业 &#xff1a; 数据科学与大数据技术 班级 &#xff1a; 20大数据本科2…

C++之STL(十)

1、适配器 2、函数适配器 #include <iostream> using namespace std;#include <algorithm> #include <vector> #include <functional>bool isOdd(int n) {return n % 2 1; } int main() {int a[] {1, 2, 3, 4, 5};vector <int> v(a, a 5);cou…

SQL小白超详细入门教程

SQL入门教程 一、SQL概述 SQL&#xff08;Structured Query Language&#xff09;是一种用于操作关系数据库&#xff08;如MySQL、Oracle、SQL Server等&#xff09;的编程语言。它是一门ANSI&#xff08;美国国家标准化组织&#xff09;的标准计算机语言&#xff0c;用于访问…

简明万年历编制(C语言)

简明万年历编制&#xff08;C语言 &#xff09; 编制万年历的要素&#xff1a; 农历公历对照&#xff0c;显示星期&#xff0c;农历干支年&#xff0c;当年生肖&#xff0c;国定节假日&#xff0c;寒天九九&#xff0c;暑日三伏&#xff0c;入梅出梅&#xff0c;节气时间&#…

机器学习原理和代码实现专辑

1. 往期文章推荐 1.【机器学习】图神经网络(NRI)模型原理和运动轨迹预测代码实现 2. 【机器学习】基于Gumbel-Sinkhorn网络的“潜在排列问题”求解 3. 【机器学习】基于Gumbel Top-k松弛技术的图形采样 4. 【机器学习】基于Softmax松弛技术的离散数据采样 5. 【机器学习】正则…

期末考试题-通过HTML编程Vue3选项式:简易购物车

<!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><!-- 引用 element-plus 样式 --><!-- 注意&#xff1a;复…

Kafka 和 RabbitMQ对比

Kafka和RabbitMQ是两种广泛使用的消息队列系统&#xff0c;它们在设计理念、架构和功能上有很多相似之处&#xff0c;但也有许多显著的区别。以下是两者之间的异同点&#xff0c;以表格的形式详细阐述&#xff1a; 特性KafkaRabbitMQ消息模型基于日志&#xff08;Log-based&am…

VSCode + GDB + J-Link 单片机程序调试实践

VSCode GDB J-Link 单片机程序调试实践 本文介绍如何创建VSCode的调试配置&#xff0c;如何控制调试过程&#xff0c;如何查看修改各种变量。 安装调试插件 在 VSCode 扩展窗口搜索安装 Cortex-Debug插件 创建调试配置 在 Run and Debug 窗口点击 create a launch.json …