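The ExponentialMovingAverage (EMA) class in AlphaFold3 maintains an exponentially weighted moving average of a neural network's parameters. During training it keeps a smoothed copy of the model's parameters. This copy damps the step-to-step fluctuations caused by individual parameter updates and can help improve the model's generalization.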
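The following minimal sketch in plain Python (no PyTorch) illustrates the smoothing effect. It applies the update rule copy = decay * copy + (1 - decay) * param to a single scalar. The helper ema_trajectory and the oscillating trajectory raw are hypothetical and exist only for illustration. The real class applies the same rule to every tensor in the model's state_dict.

```python
def ema_trajectory(values, decay, init):
    """Apply copy = decay * copy + (1 - decay) * param to each value in turn."""
    copy = init
    history = []
    for param in values:
        copy = decay * copy + (1 - decay) * param
        history.append(copy)
    return history


# Hypothetical noisy trajectory: a parameter oscillating around 1.0
raw = [1.5 if t % 2 == 0 else 0.5 for t in range(50)]
smoothed = ema_trajectory(raw, decay=0.9, init=1.0)

raw_dev = max(abs(v - 1.0) for v in raw)       # largest swing of the raw value
ema_dev = max(abs(v - 1.0) for v in smoothed)  # largest swing of the EMA copy
print(f"raw max deviation: {raw_dev:.3f}, EMA max deviation: {ema_dev:.3f}")
```

With decay = 0.9, the averaged value stays within about 0.05 of 1.0, while the raw value swings by 0.5. A decay closer to 1 smooths more, but the copy also reacts more slowly to genuine changes in the parameters.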
Main functionality
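- EMA stabilizes training by keeping a moving average of each parameter. At every step, the stored copy (copy) of each entry (param) of the model's state_dict is updated with:
$$\text{copy} \leftarrow \text{decay} \cdot \text{copy} + (1 - \text{decay}) \cdot \text{param}$$
where decay is usually close to 1. The new value receives weight (1 - decay) and the accumulated history receives weight decay. Each update therefore only nudges the copy toward the current parameters, which produces the smoothing effect.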
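The source code below does not evaluate the formula directly. Instead, it computes diff = stored - v, scales it by (1 - decay), and subtracts it in place. Algebraically, stored - (1 - decay)(stored - v) = decay * stored + (1 - decay) * v. The sketch below checks this numerically with a plain-Python mirror of the recursive update over nested dicts. The function ema_update_ and the example dicts are hypothetical. Python floats are immutable, so the sketch writes the result back with state_dict[k] = ... rather than using the in-place tensor operation stored -= diff.

```python
import math


def ema_update_(update, state_dict, decay):
    """Hypothetical plain-Python mirror of _update_state_dict_ for nested dicts of floats."""
    for k, v in update.items():
        stored = state_dict[k]
        if isinstance(v, dict):
            # Nested entry: recurse, like the non-tensor branch in the source
            ema_update_(v, stored, decay)
        else:
            # Same three steps as the source, written back explicitly
            diff = stored - v
            diff *= 1 - decay
            state_dict[k] = stored - diff


decay = 0.99
ema_params = {"weight": 2.0, "block": {"bias": -1.0}}   # stored EMA copy
model_params = {"weight": 3.0, "block": {"bias": 1.0}}  # current model values

ema_update_(model_params, ema_params, decay)

# The three-step form matches decay * copy + (1 - decay) * param
print(math.isclose(ema_params["weight"], decay * 2.0 + (1 - decay) * 3.0))
print(math.isclose(ema_params["block"]["bias"], decay * -1.0 + (1 - decay) * 1.0))
```

The subtract-a-scaled-difference form needs only one temporary tensor and updates stored in place. This is presumably why the source prefers it over allocating the two weighted terms separately.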
Source code:
from collections import OrderedDict
import copy
import torch
import torch.nn as nn
from src.utils.tensor_utils import tensor_tree_map


class ExponentialMovingAverage:
    """
    Maintains moving averages of parameters with exponential decay

    At each step, the stored copy `copy` of each parameter `param` is
    updated as follows:

        `copy = decay * copy + (1 - decay) * param`

    where `decay` is an attribute of the ExponentialMovingAverage object.
    """

    def __init__(self, model: nn.Module, decay: float):
        """
        Args:
            model:
                A torch.nn.Module whose parameters are to be tracked
            decay:
                A value (usually close to 1.) by which updates are
                weighted as part of the above formula
        """
        super(ExponentialMovingAverage, self).__init__()

        # Detached copies of every tensor in the model's state_dict
        # (parameters and buffers), independent of the live model
        clone_param = lambda t: t.clone().detach()
        self.params = tensor_tree_map(clone_param, model.state_dict())
        self.decay = decay
        self.device = next(model.parameters()).device

    def to(self, device):
        # Move all stored EMA tensors to the given device
        self.params = tensor_tree_map(lambda t: t.to(device), self.params)
        self.device = device

    def _update_state_dict_(self, update, state_dict):
        # Update the tensors in state_dict in place, without tracking gradients
        with torch.no_grad():
            for k, v in update.items():
                stored = state_dict[k]
                if not isinstance(v, torch.Tensor):
                    # Nested dict: recurse into it
                    self._update_state_dict_(v, stored)
                else:
                    # stored - (1 - decay) * (stored - v)
                    #   == decay * stored + (1 - decay) * v
                    diff = stored - v
                    diff *= 1 - self.decay
                    stored -= diff

    def update(self, model: torch.nn.Module) -> None:
        """
        Updates the stored parameters using the state dict of the pr