在深度学习中,经常会使用指数移动平均模型(Exponential Moving Average Model,EMA)这个方法对模型的参数做平均,以求提高测试指标并增加模型鲁棒。
这里的平均是是一种给予近期数据更高权重的平均方法
EMA是一种用于平滑时间序列数据的技术。它通过对数据进行加权平均来减少噪音和波动,从而提取出数据的趋势。
在深度学习中,EMA 常常用于模型的参数更新和优化过程中。它可以帮助模型在训练过程中更稳定地收敛,并提高模型的泛化能力。
一、广义EMA
假设我们有 n 个数据:
- 普通的平均数:
- EMA: ,其中, 表示前 时刻的平均值 ,是加权权重值(一般设为0.9-0.999)。
Andrew Ng 在Course 2 Improving Deep Neural Networks中讲到,EMA可以近似看成过去 个时刻 值的平均。
普通的过去 时刻的平均是这样的:
类比EMA,可以发现当 时,两式形式上相等。需要注意的是,两个平均并不是严格相等的,这里只是为了帮助理解。
实际上,EMA计算时,过去 个时刻之前的数值平均会衰变到 的加权比例,证明如下。
如果将这里的 展开,可以得到:
其中,,代入可以得到 。
二、在深度学习的优化中的EMA
在深度学习中,EMA 常常用于以下两个方面:
- 参数更新:在模型训练过程中,通常会使用梯度下降等优化算法来更新模型的参数。而使用 EMA 更新参数时,可以通过计算参数的指数移动平均值来更新参数,从而减少参数更新的噪音和波动。
- 模型预测:在模型预测阶段,可以使用训练过程中得到的参数的指数移动平均值来进行预测。这样可以减少模型预测结果的波动,提高预测的稳定性。
在深度学习的优化过程中, 是时刻的模型权重weights, 是 时刻的影子权重(shadow weights)。在梯度下降的过程中,会一直维护着这个影子权重,但是这个影子权重并不会参与训练,而是用于后续的决策和评估。基本的假设是,模型权重在最后的n步内,会在实际的最优点处抖动,所以我们取最后n步的平均,能使得模型更加的鲁棒。
EMA通过对参数进行平滑处理,使得较新的参数值对应的权重较大,较旧的参数值对应的权重较小。这样可以更好地反映参数的变化趋势,并在模型训练中提供更稳定的更新。
下面是一种常见的使用EMA进行参数更新和优化的方法,称为EMA更新策略:
- 初始化模型参数:初始化模型的参数为初始值。
- 初始化EMA:将EMA的初始值设置为与模型参数相同的初始值。
- 迭代训练:对于每个训练迭代(epoch):
a. 计算梯度:根据训练数据和当前的模型参数,计算模型的梯度。
b. 更新参数:使用梯度下降或其他优化算法更新模型参数。
c. 更新EMA:更新EMA的值,将当前的模型参数与EMA的上一个值进行平滑处理。
d. 更新模型参数:将平滑后的EMA值作为新的模型参数值。
在预测阶段,可以使用指数移动平均模型来平滑模型参数,并基于平滑后的参数进行预测。
通过使用指数移动平均模型,在模型预测过程中,可以减少参数的波动,提高预测结果的稳定性。这有助于降低模型对噪音和异常值的敏感性,提高预测的准确性和鲁棒性。
三、EMA的代码实现
实现适用于任何类型模型的指数移动平均(EMA):
EMA权重将在验证期间使用,并与原始模型权重分开存储。 如何使用EMA:
- 有时,最后的EMA检查点可能不是最佳的,因为EMA权重的指标可能会随时间出现长期振荡。参见 https://github.com/rwightman/pytorch-image-models/issues/102
- 批量归一化(Batch Norm)层和可能的其他类型的归一化层不需要在最后更新。参见以下讨论: https://github.com/rwightman/pytorch-image-models/issues/106#issuecomment-609461088 和 https://github.com/rwightman/pytorch-image-models/issues/224
- 对于目标检测,通常 SWA(随机权重平均)效果更好。参见 https://github.com/timgaripov/swa/issues/16
实现细节:
- 参见 Pytorch Lightning 中的 EMA:https://github.com/PyTorchLightning/pytorch-lightning/issues/10914
- 在多 GPU 环境中,我们广播 EMA 权重和原始权重,以便在内存中只保留一份副本。
- 当将 EMA 权重存储在 CPU + 固定内存上时,这一点尤其重要,因为固定内存是有限资源。
- 此外,我们希望避免在非 0 级别的重复操作,以减少抖动并提高性能。
reference:
【炼丹技巧】指数移动平均(EMA)的原理及PyTorch实现 - 知乎
深度学习之指数移动平均模型(EMA)介绍_ema模型-CSDN博客
TDNetGen/README.md at main · tsinghua-fib-lab/TDNetGen · GitHub