pytorch梯度截断之torch.nn.utils.clip_grad_norm_

news/2025/1/13 8:06:49/

当深度学习网络层数逐渐增加时,反向传播过程中链式法则里的梯度连乘项数也会随之增加,容易引起梯度消失和梯度爆炸。对于梯度爆炸,除了BN、shortcut、更换激活函数及权重正则化外,还有一个解决方法就是梯度剪裁,即设置一个梯度大小的上限。

torch.nn.utils.clip_grad_norm_

使用方法

在损失函数反向传播后(loss.backward())及参数更新前(optimizer.step())

函数定义

def clip_grad_norm_(parameters: _tensor_or_tensors, max_norm: float, norm_type: float = 2.0,error_if_nonfinite: bool = False) -> torch.Tensor:r"""Clips gradient norm of an iterable of parameters.The norm is computed over all gradients together, as if they wereconcatenated into a single vector. Gradients are modified in-place.Args:parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or asingle Tensor that will have gradients normalizedmax_norm (float or int): max norm of the gradientsnorm_type (float or int): type of the used p-norm. Can be ``'inf'`` forinfinity norm.error_if_nonfinite (bool): if True, an error is thrown if the totalnorm of the gradients from :attr:`parameters` is ``nan``,``inf``, or ``-inf``. Default: False (will switch to True in the future)Returns:Total norm of the parameter gradients (viewed as a single vector)."""

parameters:某组网络模型参数
max_norm:该组网络模型参数梯度的范数最大值
norm_type:范数类型,默认值为2

计算total_norm

如果norm_type为inf,则取输入的所有参数梯度范数中的最大值作为total_norm,

norms = [p.grad.detach().abs().max().to(device) for p in parameters]
total_norm = norms[0] if len(norms) == 1 else torch.max(torch.stack(norms))

其他情况,则计算输入的所有参数梯度,stack成新向量,再对向量计算范数,作为total_norm,

total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)

计算clip_coef及clip_coef_clamped

比较max_norm与total_norm,

clip_coef = max_norm / (total_norm + 1e-6)
clip_coef_clamped = torch.clamp(clip_coef, max=1.0)

若比值小于1,即max_norm<total_norm,则将参数梯度乘以clip_coef_clamped;如果max_norm>total_norm,即没有溢出预设上限,则不对梯度进行修改,

p.grad.detach().mul_(clip_coef_clamped.to(p.grad.device))

http://www.ppmy.cn/news/666537.html

相关文章

苹果库乐队怎么玩_iPhone实用技巧:怎么将抖音上的背景音乐制作成手机铃声

抖音已经成为目前热门的APP,上面有很多优秀的音乐作品。当我们遇上了喜欢的音乐,怎么做成手机铃声呢? 需要安装的APP 1.抖音短视频APP 2.库乐队GarageBand APP(苹果自带,如果删除了,可重新在App Store里下载) 3.音乐剪辑APP(在App Store里下载) 详细步骤 1.在抖音APP里面找…

苹果6p计算机在哪里设置方法,苹果手机怎么设置铃声【图文教程,不用电脑,1分钟完成】...

苹果手机好用&#xff0c;大家都知道&#xff0c;凭借着全封闭的生态系统&#xff0c;IOS的流畅性不是其他任何手机操作系统能相比的&#xff0c;苹果手机的通用铃声大家应该都很熟悉了&#xff0c;已经到了听铃声就知道用的手机是苹果手机的地步&#xff01;先来一首变调版的换…

方向导数和梯度

理性认识的三个阶段&#xff1a;定义、判断、推理。 有位博主说过&#xff0c;数学中&#xff0c;定义占60%的内容。 方向导数定义如下&#xff1a; 注意的一点是&#xff1a; 该处的alpha&#xff0c;beta角度关系是alpha beta pi/2。t*cos alpha &#xff0c;t * cos …

修理耳机完全记录

耳机的复活&#xff08;全过程&#xff09; 喜欢听歌的童鞋当然应该有一副好耳塞&#xff0c;那样才能真正享受美妙的音乐。但是耳塞是相当易坏的&#xff0c;说不定哪一天就只有一个耳朵响了&#xff0c;但是捏一捏、动一动插头部位的引线&#xff0c;偶尔会都响一下&#xff…

Beats:Beats 入门教程 (二)

这篇文章是 “Beats 入门教程 &#xff08;一&#xff09;”的续篇。在上一篇文章&#xff0c;我们主要讲述了 Beats 的一些理论方面的知识。在这篇文章中&#xff0c;我们将具体展示如何使用 Filebeat 及 Metriceat 把数据导入到我们的 Elasticsearch 并对他们进行分析。 安装…

Beats:如何调试 Beats processors

在之前的 “Beats&#xff1a;Beats processors” 文章中&#xff0c;我详细地描述了如何使用 Beat 的 processors 对数据进行清洗。在很多情况下它是非常有用的一种方法。Beats 的 processors 有很多在 ingest pipeline 的 processors 中 以及 Logstash 的过滤器中都有相应的…

华为耳机5根线怎么接线图解_【技能】小白耳机维修入门--各种耳机插头接线图--耳机维修汇总贴...

声明&#xff1a;个人观点&#xff0c;不保证完全正确。 另如果有其他型号的接线图欢迎补充或指正&#xff0c;如留言不便可联系邮箱&#xff0c;此贴持续更新&#xff01; ning.pgqq.com 耳机插头分两种 3节耳机&#xff0c;就是耳机插头上分了3段&#xff0c;一般都是只能连…

维修蓝牙耳机修好了

蓝牙耳机都坏了有两个多月了吧&#xff0c;一直想修一下&#xff0c;却没有时间&#xff0c;充电没问题&#xff0c;开不了机。因为一开始不是开不了机&#xff0c;而是比较困难&#xff0c;按键好几次还不能开机&#xff0c;开了之后又关不了&#xff0c;一旦开机使用没有其它…