【深度学习】EMA指数移动平均

devtools/2025/1/8 8:01:59/

深度学习中,经常会使用指数移动平均模型(Exponential Moving Average Model,EMA)这个方法对模型的参数做平均,以求提高测试指标并增加模型鲁棒。

这里的平均是是一种给予近期数据更高权重的平均方法

EMA是一种用于平滑时间序列数据的技术。它通过对数据进行加权平均来减少噪音和波动,从而提取出数据的趋势。

 在深度学习中,EMA 常常用于模型的参数更新和优化过程中。它可以帮助模型在训练过程中更稳定地收敛,并提高模型的泛化能力。

一、广义EMA

假设我们有 n 个数据:$\left[\theta_1, \theta_2, \ldots, \theta_n\right]$

  • 普通的平均数: $\bar{v}=\frac{1}{n} \sum_{i=1}^n \theta_i$
  • EMA:$v_t=\beta \cdot v_{t-1}+(1-\beta) \cdot \theta_t$ ,其中,$v_t$ 表示前 $t$ 时刻的平均值 $\left(v_0=0\right)$$ \beta$是加权权重值(一般设为0.9-0.999)。

Andrew Ng 在Course 2 Improving Deep Neural Networks中讲到,EMA可以近似看成过去 $1 /(1-\beta)$ 个时刻 $v$ 值的平均。

普通的过去$n$ 时刻的平均是这样的:

$v_t=\frac{(n-1) \cdot v_{t-1}+\theta_t}{n}$

类比EMA,可以发现当 $\beta=\frac{n-1}{n}$ 时,两式形式上相等。需要注意的是,两个平均并不是严格相等的,这里只是为了帮助理解。

实际上,EMA计算时,过去 $1 /(1-\beta)$ 个时刻之前的数值平均会衰变到 $\frac{1}{e}$的加权比例,证明如下。

如果将这里的 $v_t$ 展开,可以得到:

$ v_t=\beta^n v_{t-n}+(1-\beta)\left(\beta^{n-1} \theta_{t-n+1}+\ldots+\beta^0 \theta_t\right) $

其中,$n=\frac{1}{1-\beta}$,代入可以得到 $\beta^n=\beta^{\frac{1}{1-\beta}} \approx \frac{1}{e}$

二、在深度学习的优化中的EMA

深度学习中,EMA 常常用于以下两个方面:

  • 参数更新:在模型训练过程中,通常会使用梯度下降等优化算法来更新模型的参数。而使用 EMA 更新参数时,可以通过计算参数的指数移动平均值来更新参数,从而减少参数更新的噪音和波动。
  • 模型预测:在模型预测阶段,可以使用训练过程中得到的参数的指数移动平均值来进行预测。这样可以减少模型预测结果的波动,提高预测的稳定性。

深度学习的优化过程中, \theta_t 是$t$时刻的模型权重weights,$v_t$ 是 $t$ 时刻的影子权重(shadow weights)。在梯度下降的过程中,会一直维护着这个影子权重,但是这个影子权重并不会参与训练,而是用于后续的决策和评估。基本的假设是,模型权重在最后的n步内,会在实际的最优点处抖动,所以我们取最后n步的平均,能使得模型更加的鲁棒

EMA通过对参数进行平滑处理,使得较新的参数值对应的权重较大,较旧的参数值对应的权重较小。这样可以更好地反映参数的变化趋势,并在模型训练中提供更稳定的更新。

下面是一种常见的使用EMA进行参数更新和优化的方法,称为EMA更新策略

  1.         初始化模型参数:初始化模型的参数为初始值。
  2.         初始化EMA:将EMA的初始值设置为与模型参数相同的初始值。
  3.         迭代训练:对于每个训练迭代(epoch):

                a. 计算梯度:根据训练数据和当前的模型参数,计算模型的梯度。

                b. 更新参数:使用梯度下降或其他优化算法更新模型参数。

                c. 更新EMA:更新EMA的值,将当前的模型参数与EMA的上一个值进行平滑处理

                d. 更新模型参数:平滑后的EMA值作为新的模型参数值

预测阶段,可以使用指数移动平均模型来平滑模型参数,并基于平滑后的参数进行预测。

通过使用指数移动平均模型,在模型预测过程中,可以减少参数的波动,提高预测结果的稳定性。这有助于降低模型对噪音和异常值的敏感性,提高预测的准确性和鲁棒性。

三、EMA的代码实现

实现适用于任何类型模型的指数移动平均(EMA):


EMA权重将在验证期间使用,并与原始模型权重分开存储。  如何使用EMA: 

  • 有时,最后的EMA检查点可能不是最佳的,因为EMA权重的指标可能会随时间出现长期振荡。参见 https://github.com/rwightman/pytorch-image-models/issues/102 
  • 批量归一化(Batch Norm)层和可能的其他类型的归一化层不需要在最后更新。参见以下讨论:  https://github.com/rwightman/pytorch-image-models/issues/106#issuecomment-609461088 和 https://github.com/rwightman/pytorch-image-models/issues/224 
  • 对于目标检测,通常 SWA(随机权重平均)效果更好。参见 https://github.com/timgaripov/swa/issues/16 

实现细节: 

  • 参见 Pytorch Lightning 中的 EMA:https://github.com/PyTorchLightning/pytorch-lightning/issues/10914 
  • 在多 GPU 环境中,我们广播 EMA 权重和原始权重,以便在内存中只保留一份副本。 
  • 当将 EMA 权重存储在 CPU + 固定内存上时,这一点尤其重要,因为固定内存是有限资源。 
  • 此外,我们希望避免在非 0 级别的重复操作,以减少抖动并提高性能。 

reference:

【炼丹技巧】指数移动平均(EMA)的原理及PyTorch实现 - 知乎

深度学习之指数移动平均模型(EMA)介绍_ema模型-CSDN博客

TDNetGen/README.md at main · tsinghua-fib-lab/TDNetGen · GitHub


http://www.ppmy.cn/devtools/148841.html

相关文章

32单片机从入门到精通之硬件架构——总线系统(二)

一个真正强大的人,不会把太多心思花在取悦和亲附别人上面,所谓的圈子、资源,都只是衍生品,最重要的是提高自己的内功。 你要默默做好你该做的事情,等你变得足够优秀时,你想要的都会主动来找你,你…

java 转义 反斜杠 Unexpected internal error near index 1

代码: String str"a\\c"; //出现异常,Unexpected internal error near index 1 //System.out.println(str.replaceAll("\\", "c"));//以下三种都正确 System.out.println(str.replace(\\, c)); System.out.println(str.r…

python实战(十三)——基于Bert+HDBSCAN的微博热搜数据挖掘

一、任务目标 众所周知,微博热搜几乎是许多网友的主要新闻来源,上面实时更新着当前最新的社会消息,其时效性甚至比每天晚上播出的新闻联播还要强。这篇文章,我们使用来自Kaggle的《MicroBlog-Hot-Search-Labeled》数据集&#xff…

spark on hive 参数

set hive.execution.enginespark; set spark.app.nametest9999; set spark.executor.cores5; set spark.executor.memory20G; set spark.executor.instances5; set spark.driver.memory5G; set spark.memory.fraction0.9; –定义了 Spark 作业中每个 stage 的默认 task 数量。 …

python代码实现了一个金融数据处理和分析的功能,主要围绕国债期货及相关指数数据展开

# 忽略某些模块的提示信息 import warnings warnings.filterwarnings("ignore") # 在全局配置中添加RQData账号信息 import rqdatac as rq from typing import List import pandas as pd import numpy as np import re from datetime import datetime, timedelta,tim…

论文精读:Root Cause Analysis in Microservice Using Neural Granger Causal Discovery

Root Cause Analysis in Microservice Using Neural Granger Causal Discovery 摘要 微服务架构因其可扩展性、维护性和灵活性在IT运营中得到广泛应用,但系统故障时SREs难以确定根本原因。以往研究使用结构化学习方法建立因果关系图,但忽略了时间序列数据的时间顺序,无法利…

DCU异构程序——GEMM

目录 一、概述 二、程序实现 三、编译运行 一、概述 HIP属于显式编程模型,需要在程序中明确写出并行控制语句,包括数据传输、核函数启动等。核函数是运行在DCU上的函数,在CPU端运行的部分称为主机端(主要是执行管理和启动&…

机器学习期末复习知识点

Apriori算法 基本思想: 1 找到所有的频繁项集 2 由频繁项集找到所有的强关联规则 目标: 给定一个事务集T,关联规则挖掘的目标是找到所有规则满足以下要求: 1 support>minsup threshold 2 confidence>minconf thresho…