AF3 ExponentialMovingAverage类解读

embedded/2025/2/11 14:16:27/

AlphaFold3 的 ExponentialMovingAverage (EMA) 类,用于维护神经网络模型参数的指数加权移动平均。它可以在训练过程中对模型的参数进行平滑处理,以减缓参数更新的波动,帮助提升模型的泛化能力。

主要功能

  • EMA 通过对每个参数的移动平均来稳定模型的训练过程。在每一步,参数的拷贝(copy)通过公式:

来更新。这个公式通过给新值和历史值加权来产生平滑效果。

源代码:

from collections import OrderedDict
import copy
import torch
import torch.nn as nn
from src.utils.tensor_utils import tensor_tree_mapclass ExponentialMovingAverage:"""Maintains moving averages of parameters with exponential decayAt each step, the stored copy `copy` of each parameter `param` isupdated as follows:`copy = decay * copy + (1 - decay) * param`where `decay` is an attribute of the ExponentialMovingAverage object."""def __init__(self, model: nn.Module, decay: float):"""Args:model:A torch.nn.Module whose parameters are to be trackeddecay:A value (usually close to 1.) by which updates areweighted as part of the above formula"""super(ExponentialMovingAverage, self).__init__()clone_param = lambda t: t.clone().detach()self.params = tensor_tree_map(clone_param, model.state_dict())self.decay = decayself.device = next(model.parameters()).devicedef to(self, device):self.params = tensor_tree_map(lambda t: t.to(device), self.params)self.device = devicedef _update_state_dict_(self, update, state_dict):with torch.no_grad():for k, v in update.items():stored = state_dict[k]if not isinstance(v, torch.Tensor):self._update_state_dict_(v, stored)else:diff = stored - vdiff *= 1 - self.decaystored -= diffdef update(self, model: torch.nn.Module) -> None:"""Updates the stored parameters using the state dict of the pr

http://www.ppmy.cn/embedded/161346.html

相关文章

ctfshow-36D杯

ctfshow-36D杯 给你shell ($obj[secret] ! $flag_md5 ) ? haveFun($flag) : echo "here is your webshell: $shell_path"; 这是个弱比较,输入?give_me_shell前三个是0说明二进制小于1000000就是ASCII的64, 0-32是不可见或非打印字符&…

日志2025.2.9

日志2025.2.9 1.增加了敌人挥砍类型 2.增加了敌人的死亡状态 在敌人身上添加Ragdoll,死后激活布偶模式 public class EnemyRagdoll : MonoBehaviour { private Rigidbody[] rigidbodies; private Collider[] colliders; private void Awake() { rigidbodi…

C#Halcon窗体鼠标交互生成菜单

窗体鼠标交互生成菜单,移动鼠标作出相应的提示,并且可以进入相应事件。(一般可以应用到成品效果展示,或实战项目检测失败时,需做出人机交互选择时可应用,相对于按键交互,可以优化UI布局&#xf…

1.攻防世界 unserialize3(wakeup()魔术方法、反序列化工作原理)

进入题目页面如下 直接开审 <?php // 定义一个名为 xctf 的类 class xctf {// 声明一个公共属性 $flag&#xff0c;初始值为字符串 111public $flag 111;// 定义一个魔术方法 __wakeup()// 当对象被反序列化时&#xff0c;__wakeup() 方法会自动调用public function __wa…

轻量级服务器http-server

安装 sudo npm install http-server -g 运行 1. 直接去到要跑起来的目录&#xff0c;在终端输入 cd xxxx文件夹http-server //只输入http-server的话&#xff0c;更新了代码后&#xff0c;页面不会同步更新http-server -c-1 //同步更新页面http-server -a 127.0.0.1 -p 808…

Jetpack之ViewBinding和DataBinding的区别

DataBinding与ViewBinding推出的时间 DataBinding推出的时间比较早‌&#xff0c;ViewBinding是后来推出的。 DataBinding在2015年7月发布的Android Studio v1.3.0版本上引入&#xff0c;并在2016年4月Android Studio v2.0.0版本上正式支持‌。 ViewBinding则是在2019年Andr…

在微服务中,如何使用feign在各个微服务中进行远程调用

在微服务中&#xff0c;如何使用feign在不同微服务中进行远程调用 在微服务中&#xff0c;如何使用feign在不同微服务中进行远程调用 步骤&#xff1a; 第一步&#xff1a; 引入feign依赖 <dependency><groupId>org.springframework.cloud</groupId><…

DeepSeek R1 Distill Llama 70B(免费版)API使用详解

DeepSeek R1 Distill Llama 70B&#xff08;免费版&#xff09;API使用详解 在人工智能领域&#xff0c;随着技术的不断进步&#xff0c;各种新的模型和应用如雨后春笋般涌现。今天&#xff0c;我们要为大家介绍的是OpenRouter平台上提供的DeepSeek R1 Distill Llama 70B&…