详解L1和L2正则化

news/2025/3/5 2:03:51/

大纲:

  • L1和L2的区别以及范数相关知识
  • 对参数进行L1和L2正则化的作用与区别
  • pytorch实现L1与L2正则化
  • 对特征进行L2正则化的作用

L1和L2的区别以及范数

  使用机器学习方法解决实际问题时,我们通常要用L1或L2范数做正则化(regularization),从而限制权值大小,减少过拟合风险,故其又称为权重衰减。特别是在使用梯度下降来做目标函数优化时。

L1和L2的区别
在机器学习中,

  • L1范数(L2 normalization)是指向量中各个元素绝对值之和,通常表述为 ∥ w i ∥ 1 \|\boldsymbol{w_i}\|_1 wi1,线性回归中使用L1正则的模型也叫Lasso regularization
    比如 向量A=[1,-1,3], 那么A的L1范数为 |1|+|-1|+|3|.

  • L2范数指权值向量w中各个元素的平方和然后再求平方根(可以看到Ridge回归的L2正则化项有平方符号),通常表示为 ∥ w i ∥ 2 \|\boldsymbol{w_i}\|_2 wi2, 线性回归中使用L2正则的模型又叫岭回归(Ringe regularization)。

简单总结一下就是:

  • L1范数: 为x向量各个元素绝对值之和。
  • L2范数: 为x向量各个元素平方和的1/2次方,L2范数又称Euclidean范数或者Frobenius范数
  • Lp范数: 为x向量各个元素绝对值p次方和的1/p次方.

下图为p从无穷到0变化时,三维空间中到原点的距离(范数)为1的点构成的图形的变化情况。以常见的L-2范数(p=2)为例,此时的范数也即欧氏距离,空间中到原点的欧氏距离为1的点构成了一个球面
在这里插入图片描述

参数正则化作用

  • L1: 为模型加入先验, 简化模型, 使权值稀疏,由于权值的稀疏,从而过滤掉一些无用特征,防止过拟合
  • L2: 根据L2的特性,它会使得权值减小,即使平滑权值,一定程度上也能和L1一样起到简化模型,加速训练的作用,同时可防止模型过拟合

关于为什么L1会使得权重稀疏,而L2会使得权值平滑,可以参考知乎上一位答主的台大林轩田老师人工智能基石课笔记,从凸优化,梯度更新,概率分布三个角度诠释L1和L2正则化的原理和区别。我把笔记搬运到这:

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

pytorch实现L1与L2正则化

网上很多关于L2和L1正则化的对象都是针对参数的,或者说权重,即权重衰减,可以用pytorch很简单的实现L2惩罚:

class torch.optim.Adam(params, lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)

如上,weight_decay参数即为L2惩罚项前的系数
举个栗子,对模型中的某些参数进行惩罚时

#定义一层感知机
net = nn.Linear(num_inputs, 1)
#自定义参数初始化
nn.init.normal_(net.weight, mean=0, std=1)
nn.init.normal_(net.bias, mean=0, std=1)
optimizer_w = torch.optim.SGD(params=[net.weight], lr=lr, weight_decay=wd) # 对权重参数衰减,惩罚项前的系数为wd
optimizer_b = torch.optim.SGD(params=[net.bias], lr=lr)  # 不对偏差参数衰减

而对于L1正则化或者其他的就比较麻烦了,因为pytorch优化器只封装了L2惩罚功能,参考pytorch实现L2和L1正则化regularization的方法

class Regularization(torch.nn.Module):def __init__(self,model,weight_decay,p=2):''':param model 模型:param weight_decay:正则化参数:param p: 范数计算中的幂指数值,默认求2范数,当p=0为L2正则化,p=1为L1正则化'''super(Regularization, self).__init__()if weight_decay <= 0:print("param weight_decay can not <=0")exit(0)self.model=modelself.weight_decay=weight_decayself.p=pself.weight_list=self.get_weight(model)self.weight_info(self.weight_list)def to(self,device):'''指定运行模式:param device: cude or cpu:return:'''self.device=devicesuper().to(device)return selfdef forward(self, model):self.weight_list=self.get_weight(model)#获得最新的权重reg_loss = self.regularization_loss(self.weight_list, self.weight_decay, p=self.p)return reg_lossdef get_weight(self,model):'''获得模型的权重列表:param model::return:'''weight_list = []for name, param in model.named_parameters():if 'weight' in name:weight = (name, param)weight_list.append(weight)return weight_listdef regularization_loss(self,weight_list, weight_decay, p=2):'''计算张量范数:param weight_list::param p: 范数计算中的幂指数值,默认求2范数:param weight_decay::return:'''# weight_decay=Variable(torch.FloatTensor([weight_decay]).to(self.device),requires_grad=True)# reg_loss=Variable(torch.FloatTensor([0.]).to(self.device),requires_grad=True)# weight_decay=torch.FloatTensor([weight_decay]).to(self.device)# reg_loss=torch.FloatTensor([0.]).to(self.device)reg_loss=0for name, w in weight_list:l2_reg = torch.norm(w, p=p)reg_loss = reg_loss + l2_regreg_loss=weight_decay*reg_lossreturn reg_lossdef weight_info(self,weight_list):'''打印权重列表信息:param weight_list::return:'''print("---------------regularization weight---------------")for name ,w in weight_list:print(name)print("---------------------------------------------------")

class Regularization的使用


# 检查GPU是否可用
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")print("-----device:{}".format(device))
print("-----Pytorch version:{}".format(torch.__version__))weight_decay=100.0 # 正则化参数model = my_net().to(device)
# 初始化正则化
if weight_decay>0:reg_loss=Regularization(model, weight_decay, p=2).to(device)
else:print("no regularization")criterion= nn.CrossEntropyLoss().to(device) # CrossEntropyLoss=softmax+cross entropy
optimizer = optim.Adam(model.parameters(),lr=learning_rate)#不需要指定参数weight_decay# train
batch_train_data=...
batch_train_label=...out = model(batch_train_data)# loss and regularization
loss = criterion(input=out, target=batch_train_label)
if weight_decay > 0:loss = loss + reg_loss(model)
total_loss = loss.item()# backprop
optimizer.zero_grad()#清除当前所有的累积梯度
total_loss.backward()
optimizer.step()

特征正则化作用

上面介绍了对于权重进行正则化的作用以及具体实现,其实在很多模型中,也会对特征采用L2归一化,有的时候在训练模型时,经过几个batch后,loss会变成nan,此时,如果你在特征后面加上L2归一化,可能可以很好的解决这个问题,而且有时会影响训练的效果,深有体会。
L2正则的原理比较简单,如下公式:
y = x i ∑ i = 0 D x i 2 \boldsymbol{y} = \frac{\boldsymbol{x_i}}{\sum_{i=0}^D\boldsymbol{{x_i}}^2 } y=i=0Dxi2xi
其中D为向量的长度,经过l2正则后 x i \boldsymbol{x_i} xi向量的元素平方和等于1

python实现

def l2norm(X, dim=-1, eps=1e-12):"""L2-normalize columns of X"""norm = torch.pow(X, 2).sum(dim=dim, keepdim=True).sqrt() + epsX = torch.div(X, norm)return X

在SSD目标检测的conv4_3层便使用了L2Norm

对特征进行L2正则的具体作用如下:

  • 防止梯度消失或者梯度爆炸
  • 统一量纲,加快模型收敛

参考:

机器学习中L1和L2的直观理解
几种范数的介绍


http://www.ppmy.cn/news/546473.html

相关文章

L1正则化和L2正则化

在机器学习以及深度学习中我们经常会看到正则化这一名词&#xff0c;下面就浅谈一下什么是正则化&#xff1f;以及正则化的意义所在&#xff1f; 一、什么是正则化&#xff1f; 正则化项 (又称惩罚项)&#xff0c;惩罚的是模型的参数&#xff0c;其值恒为非负 λ是正则化系数&…

L1和L2简单易懂的理解

一、正则化&#xff08;Regularization&#xff09; 机器学习中几乎都可以看到损失函数后面会添加一个额外项&#xff0c;常用的额外项一般有两种&#xff0c;一般英文称作ℓ1ℓ1-norm和ℓ2ℓ2-norm&#xff0c;中文称作L1正则化和L2正则化&#xff0c;或者L1范数和L2范数。 L1…

L1与L2分别服从什么分布?

L1是拉普拉斯分布&#xff0c;L2是高斯分布。 正则化是一种回归的形式&#xff0c;它将系数估计&#xff08;coefficient estimate&#xff09;朝零的方向进行约束、调整或缩小。也就是说&#xff0c;正则化可以在学习过程中降低模型复杂度和不稳定程度&#xff0c;从而避免过…

机器学习中的L1与L2正则化图解

文章来源于SAMshare &#xff0c;作者flora &#x1f699;正则项 正则项的作用—防止模型过拟合&#xff0c;正则化可以分为L1范数正则化与L2范数正则化 &#x1f699;L1 AND L2范数 范数&#xff1a; 范数其实在 [0,∞)范围内的值&#xff0c;是向量的投影大小 在机器学习中一…

机器学习之L1、L2的区别与相关数学基础知识

机器学习数学基础概念、知识汇总&#xff08;线代&#xff09; 数学概念映射与函数线性与非线性空间线性空间&#xff08;向量空间&#xff09;向量基矩阵范数L-p范数L-0范数L-1范数L-2范数L-∞范数机器学习中的正则化L1正则化L2正则化L1与L2的区别 数学概念 本文将总结在机器…

金蝶云星空财务账套数据库中了.locked勒索病毒的解密步骤和预防方式

最近&#xff0c;金蝶云星空财务账套的数据库遭到了一次严重的勒索病毒攻击&#xff0c;导致数据库中重要数据被加密。这种攻击对企业来说是一种巨大的威胁&#xff0c;因为数据是企业的核心资产之一。而此次攻击的病毒为.locked后缀勒索病毒&#xff0c;而locked勒索病毒在本月…

RT-Thread-05-空闲线程和两个常用的钩子函数

空闲线程和两个钩子函数 空闲线程是一个比较特殊的系统线程&#xff0c;它具备最低优先级&#xff0c;当系统中无其他就绪线程可运行时&#xff0c;调度器将调度到空闲线程&#xff1b;空闲线程还负责一些系统资源回收以及将一些处于关闭状态的线程从线程调度列表中移除&#x…

Linux基础_1

目录 一、用户登录 1、root用户 2、普通&#xff08;非特权&#xff09;用户 二、终端terminal 1、终端类型 2、查看当前的终端设备 三、交互式接口 1、概念&#xff1a;启动终端后&#xff0c;在终端设备附加一个交互式应用程序 2、类型 3、什么是Shell 4、各种She…