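As a deep network gets deeper, the chain rule multiplies more factors together during backpropagation, which makes vanishing and exploding gradients more likely. Exploding gradients can be addressed with batch normalization (BN), shortcut connections, a different activation function, or weight regularization. Another remedy is gradient clipping, which caps the size of the gradient at a preset upper bound.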
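The toy sketch below makes the chain-rule argument concrete. It treats every layer's local derivative as one constant factor, which no real network satisfies exactly, and `chain_product` is a hypothetical helper. It shows that factors slightly above 1 blow up with depth, while factors below 1 shrink toward zero.

```python
def chain_product(factor: float, depth: int) -> float:
    """Product of `depth` identical per-layer derivative factors (toy model)."""
    result = 1.0
    for _ in range(depth):
        result *= factor  # one more layer in the chain rule
    return result

for depth in (10, 50, 100):
    # factor 1.5 -> exploding, factor 0.5 -> vanishing
    print(depth, chain_product(1.5, depth), chain_product(0.5, depth))
```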
torch.nn.utils.clip_grad_norm_
Usage
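Call it after the loss has been backpropagated (loss.backward()) and before the parameters are updated (optimizer.step()). At that point the gradients have been computed but not yet applied, so they can be rescaled in place.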
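A minimal training-loop sketch showing where the call goes. The model, the data, the learning rate and `max_norm = 1.0` are arbitrary choices for illustration, not recommended values.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy model and data, for illustration only.
model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 1))
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.MSELoss()
x = torch.randn(64, 10)
y = torch.randn(64, 1) * 100  # large targets produce large gradients

max_norm = 1.0  # assumed threshold; tune per task
for step in range(5):
    optimizer.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()  # 1. compute gradients
    # 2. clip all gradients in place; returns the total norm *before* clipping
    total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
    optimizer.step()  # 3. update parameters with the clipped gradients
    print(f"step {step}: loss={loss.item():.2f}, grad norm before clipping={total_norm.item():.2f}")
```

The returned value is useful for logging how often clipping actually kicks in.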
Function definition
def clip_grad_norm_(parameters: _tensor_or_tensors, max_norm: float, norm_type: float = 2.0,
                    error_if_nonfinite: bool = False) -> torch.Tensor:
    r"""Clips gradient norm of an iterable of parameters.

    The norm is computed over all gradients together, as if they were
    concatenated into a single vector. Gradients are modified in-place.

    Args:
        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
            single Tensor that will have gradients normalized
        max_norm (float or int): max norm of the gradients
        norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
            infinity norm.
        error_if_nonfinite (bool): if True, an error is thrown if the total
            norm of the gradients from :attr:`parameters` is ``nan``,
            ``inf``, or ``-inf``. Default: False (will switch to True in the future)

    Returns:
        Total norm of the parameter gradients (viewed as a single vector).
    """
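parameters: the model parameters whose gradients will be clipped (an iterable of Tensors such as model.parameters(), or a single Tensor)
max_norm: the maximum allowed norm of this group's gradients, measured over all of them together
norm_type: the type of p-norm to use; defaults to 2, and float('inf') selects the infinity norm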
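A small sketch of the arguments. Here `parameters` is a single Tensor whose `.grad` is set by hand to a hypothetical value, and `norm_type=float("inf")` selects the infinity norm.

```python
import torch

# A single hypothetical parameter with a hand-set gradient.
w = torch.zeros(4, requires_grad=True)
w.grad = torch.tensor([3.0, -4.0, 0.0, 0.0])

# `parameters` may be one Tensor; the infinity norm is the largest |g| (here 4.0).
total = torch.nn.utils.clip_grad_norm_(w, max_norm=2.0, norm_type=float("inf"))
print(total)   # total norm before clipping
print(w.grad)  # scaled in place so that max |g| is about 2.0
```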
Computing total_norm
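If norm_type is inf, the function computes each parameter's largest absolute gradient entry (its infinity norm). The maximum of these values across all parameters becomes total_norm: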
norms = [p.grad.detach().abs().max().to(device) for p in parameters]
total_norm = norms[0] if len(norms) == 1 else torch.max(torch.stack(norms))
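The same logic can be traced in plain Python. Lists stand in for flattened gradients, and the values are made up for illustration. Taking the per-parameter maximum and then the maximum across parameters gives the same result as scanning every gradient entry at once.

```python
# Each inner list stands in for one hypothetical parameter's flattened gradient.
grads = [[0.5, -2.0, 1.0], [3.0, -0.1], [-7.5]]

# Per-parameter infinity norm, analogous to p.grad.detach().abs().max()
norms = [max(abs(g) for g in grad) for grad in grads]

# Max over parameters, analogous to torch.max(torch.stack(norms))
total_norm = norms[0] if len(norms) == 1 else max(norms)

# Same result as taking the max |g| over all entries directly
flat_max = max(abs(g) for grad in grads for g in grad)
print(norms, total_norm, flat_max)
```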
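Otherwise, the function computes the norm_type-norm of each parameter's gradient and stacks these per-parameter norms into a new vector. The norm_type-norm of that vector is total_norm, which equals the norm of all gradients concatenated into a single vector: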
total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
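Why does stacking per-parameter norms give the norm of the concatenated vector? For $p \ge 1$, $\|g\|_p = \big(\sum_j \|g_j\|_p^p\big)^{1/p}$, because the sum of $|x_i|^p$ over all entries can be grouped parameter by parameter. The plain-Python check below uses a hypothetical `p_norm` helper and made-up gradients.

```python
def p_norm(values, order):
    """p-norm of a flat list of numbers (hypothetical helper, order >= 1)."""
    return sum(abs(v) ** order for v in values) ** (1.0 / order)

# Each inner list stands in for one parameter's flattened gradient.
grads = [[3.0, 4.0], [12.0], [0.0, 0.0]]

for order in (1.0, 2.0, 3.0):
    per_param = [p_norm(g, order) for g in grads]       # like torch.norm(p.grad, norm_type)
    nested = p_norm(per_param, order)                   # norm of the stacked per-parameter norms
    flat = p_norm([v for g in grads for v in g], order)  # norm of all gradients concatenated
    print(order, nested, flat)
```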
Computing clip_coef and clip_coef_clamped
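Compare max_norm with total_norm by taking their ratio. The small constant 1e-6 prevents division by zero when all gradients are zero, and clamping at 1.0 ensures the gradients are never scaled up: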
clip_coef = max_norm / (total_norm + 1e-6)
clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
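The two lines above work out as follows. This plain-float sketch uses a hypothetical helper `clamped_clip_coef`, with `min(..., 1.0)` standing in for `torch.clamp(..., max=1.0)`. The numbers are made up.

```python
def clamped_clip_coef(total_norm: float, max_norm: float, eps: float = 1e-6) -> float:
    """Plain-float mirror of clip_coef / clip_coef_clamped (hypothetical helper)."""
    clip_coef = max_norm / (total_norm + eps)
    return min(clip_coef, 1.0)  # plays the role of torch.clamp(clip_coef, max=1.0)

print(clamped_clip_coef(10.0, 1.0))  # about 0.1: gradients get scaled down
print(clamped_clip_coef(0.5, 1.0))   # 1.0: gradients are left as they are
print(clamped_clip_coef(0.0, 1.0))   # 1.0: eps avoids dividing by zero

grad = [6.0, 8.0]                   # its 2-norm is 10
coef = clamped_clip_coef(10.0, 1.0)
clipped = [g * coef for g in grad]  # roughly [0.6, 0.8]; same direction, norm about 1
```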
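If the ratio is less than 1 (total_norm exceeds max_norm), every gradient is multiplied in place by clip_coef_clamped, which brings the total norm down to approximately max_norm. All gradients are scaled by the same factor, so their overall direction is preserved. If total_norm does not exceed max_norm, clip_coef_clamped is clamped to 1.0 and the gradients are effectively unchanged: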
for p in parameters:
    p.grad.detach().mul_(clip_coef_clamped.to(p.grad.device))
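The steps above can be assembled into one function and compared against the built-in. `manual_clip_grad_norm_` is a hypothetical name, and this is a sketch of the described logic rather than PyTorch's actual source. Among other things, it omits `error_if_nonfinite` and the `foreach` fast path that newer releases add. The test model and loss are arbitrary and were chosen only to produce large gradients.

```python
import math
import torch

def manual_clip_grad_norm_(parameters, max_norm, norm_type=2.0):
    """Sketch of the steps described above; not PyTorch's actual source."""
    params = [p for p in parameters if p.grad is not None]
    if not params:
        return torch.tensor(0.0)
    device = params[0].grad.device
    if norm_type == math.inf:
        # Largest |g| per parameter, then the max across parameters
        norms = [p.grad.detach().abs().max().to(device) for p in params]
        total_norm = norms[0] if len(norms) == 1 else torch.max(torch.stack(norms))
    else:
        # p-norm of the stacked per-parameter p-norms
        total_norm = torch.norm(
            torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in params]),
            norm_type,
        )
    clip_coef = max_norm / (total_norm + 1e-6)
    clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
    for p in params:
        p.grad.detach().mul_(clip_coef_clamped.to(p.grad.device))  # in-place rescale
    return total_norm  # norm before clipping

# Two identical copies of a toy model, each given the same large gradients.
torch.manual_seed(0)
model_a = torch.nn.Linear(8, 3)
model_b = torch.nn.Linear(8, 3)
model_b.load_state_dict(model_a.state_dict())
x = torch.randn(16, 8)
for m in (model_a, model_b):
    (m(x) ** 2).sum().mul(50).backward()

ref = torch.nn.utils.clip_grad_norm_(model_a.parameters(), max_norm=1.0)
mine = manual_clip_grad_norm_(model_b.parameters(), max_norm=1.0)
print(ref.item(), mine.item())
```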