机器学习(1)——线性回归、线性分类与梯度下降

embedded/2024/10/16 0:19:37/

文章目录

  • 线性回归
  • 线性分类
    • 线性可分数据
    • 线性不可分数据
    • 逻辑回归
    • 支持向量机
  • 梯度下降
    • 批量梯度下降
    • 随机梯度下降
    • 批量随机梯度下降

线性回归

概述:

在一元线性回归中,我们假设目标变量y与特征变量x存在线性关系,模型表达式为:
y = W 0 + W 1 x 1 + W 2 x 2 + ⋯ + W n x n + ϵ y=W_0+W_1x_1+W_2x_2+\cdots+W_nx_n+\epsilon y=W0+W1x1+W2x2++Wnxn+ϵ
其中:

  • W 0 W_0 W0是截距(bias)
  • W 1 , W 2 , . . . , W n W_1,W_2,...,W_n W1,W2,...,Wn是回归系数(权重)
  • ε是噪声或误差项

对于多个样本的情况,可以将特征表示为矩阵X,目标值为向量y。在这种情况下,线性回归模型可以写作:
y ^ = X W \hat{y}=XW y^=XW
其中:

  • X是输入特征矩阵(包含所有样本)
  • W是权重向量
  • y ^ \hat{y} y^是模型的预测值向量

损失函数(MSE):

为了让模型拟合数据,通常我们会使用均方误差(MSE)作为损失函数,来度量模型的预测值 y ^ \hat{y} y^与实际值y之间的差异:
J ( W ) = 1 2 m ∑ i = 1 m ( y ^ i − y i ) 2 = 1 2 m ( X W − y ) T ( X W − y ) J(W)=\frac1{2m}\sum_{i=1}^m(\hat{y}_i-y_i)^2=\frac1{2m}(XW-y)^T(XW-y) J(W)=2m1i=1m(y^iyi)2=2m1(XWy)T(XWy)
其中m是样本数

线性回归的闭式解: 指通过直接求解方程组,得到回归模型的参数(即权重向量)的解析解,而不需要通过迭代的优化算法(如梯度下降)来找到最优解

闭式解的推导:

最小化损失函数 J ( W ) J(W) J(W)以找到最优权重W。通过对 W W W求导并让导数为0,可以得到线性回归的解析解。首先对损失函数求导:
∂ J ( W ) ∂ W = 1 m X T ( X W − y ) \frac{\partial J(W)}{\partial W}=\frac1mX^T(XW-y) WJ(W)=m1XT(XWy)
将导数设置为0,求解 W W W
X T ( X W − y ) = 0 X T X W = X T y \begin{gathered} X^T(XW-y)=0 \\ X^TXW=X^Ty \end{gathered} XT(XWy)=0XTXW=XTy
可以通过矩阵求逆的方式得到 W W W
W = ( X T X ) − 1 X T y W=(X^TX)^{-1}X^Ty W=(XTX)1XTy
这就是线性回归的闭式解公式

闭式解的核心思想:

  • 直接求解:通过解析方法一次性求出最优权重 W W W,不需要像梯度下降一样逐步优化
  • 线性代数运行:通过矩阵转置、乘法和求逆等线性代数运算实现
  • 适用场景:对于小规模数据集、闭式解可以快速得到结果。然而当数据量非常大时,计算 X T X X^TX XTX的逆矩阵可能非常耗时,因此在大数据集上通常采用梯度下降等数据优化方法

优点与缺点:

  • 优点:直接得到最优解,计算速度快(适合小数据集
  • 缺点:对于高维度数据集,矩阵求逆的计算复杂度较高 O ( n 3 ) O(n^3) O(n3),在数据量过大时不适用

线性分类

概述: 线性分类器是基于线性决策边界进行分类的模型,形式上它会学到一个权重向量W和一个偏置b,其决策规则可以表示为:
f ( x ) = w T x + b f(\mathbf{x})=\mathbf{w}^T\mathbf{x}+b f(x)=wTx+b
在这种情况下,分类是根据f(X)的符号来进行的:

  • 如果f(x) > 0,则将数据点分类为正类
  • 否则为负类

这种方式只输出一个硬分类的结果,没有给出分类的概率

线性可分数据

概述: 指的是数据集中的不同类别可以通过一条直线(在二维空间中)或一个超平面(在高维空间中)完全分开,没有任何重叠或错误分类

特点:

  • 可以找到一个线性决策边界(如一条直线或一个超平面),使得数据集中所有点都可以准确分到正确的类别
  • 这种类比的数据适合使用线性分类器,如感知器、线性支持向量机(SVM)等

二维平面中的例子:

类别 A: (蓝色点)   类别 B: (红色点)蓝   蓝   蓝   蓝(直线)
红   红   红   红

在这种情况下,直线可以完全分开这两类点,没有任何交错

线性不可分数据

概述: 是指数据集中不同类别的点不能通过一条直线(或超平面)来完全分开,一些数据会落在错误的边界一侧,导致无法完美分类

特点:

  • 没有单一的线性边界可以准确分隔数据类别
  • 线性分类器在这种情况下表现不佳,因为它们依赖于线性分界线
  • 处理线性不可分数据的常用方法包括使用非线性模型(如核化支持向量机、决策树)或者对特征进行转换,,使数据在更高维空间中线性可分

二维平面中的例子:

类别 A: (蓝色点)   类别 B: (红色点)蓝   红   蓝   红
红   蓝   红   蓝

在这种情况下,无论如何放置一条直线,总会有一部分点被错误分类

解决线性不可分问题:

  • 引入非线性分类器:如使用核支持向量机(SVM),将数据映射到高维空间,使其在高维空间中线性可分。
  • 增加特征:通过添加多项式特征或交互特征,可以在输入空间中创建一个更复杂的模型。
  • 使用核技巧 (Kernel Trick):这是 SVM 的一个重要特性,通过核函数将低维数据映射到高维空间,使原本线性不可分的数据在高维空间中变得线性可分。

逻辑回归

概述: 是一种用于二分类问题的线性模型,尽管名字里有回归,它实际上用于分类任务。可以说是线性分类的一种特例,但它采用了概率的方式进行分类决策。

核心思想:

逻辑回归的目标是通过学习到的模型预测某个输入属于某个类别的概率。其基本形式是将线性回归的输出通过Sigmoid函数转换为一个介于0到1之间的概率值

  • 线性部分:给定一个输入向量X和模型参数W,线性部分的输出为:
    z = W T X + b z=W^TX+b z=WTX+b
    这个W是权重向量,b是偏置

  • Sigmoid函数:将线性输出z转换为概率值p:
    p = 1 1 + e − z p=\frac1{1+e^{-z}} p=1+ez1
    这个p表示预测结果为正类的概率

  • 损失函数:逻辑回归的损失函数通常是交叉熵损失,用于评估模型预测的概率分布和实际标签之间的差异:
    J ( W ) = − 1 m ∑ i = 1 m [ y i log ⁡ ( y ^ i ) + ( 1 − y i ) log ⁡ ( 1 − y ^ i ) ] J(W)=-\frac1m\sum_{i=1}^m\left[y_i\log(\hat{y}_i)+(1-y_i)\log(1-\hat{y}_i)\right] J(W)=m1i=1m[yilog(y^i)+(1yi)log(1y^i)]
    其中, y ^ i \hat{y}_i y^i是第i个样本的预测概率, y i y_i yi是实际标签,m是样本数

  • Logistic回归的梯度公式:
    ∇ L ( W ) = X T ( sigmoid ( X W ) − y ) \nabla L(W)=X^T(\text{sigmoid}(XW)-y) L(W)=XT(sigmoid(XW)y)

训练逻辑回归的过程: 通过最小化损失函数的值,调整权重W和偏置b。最常用的优化算法是梯度下降,其中包括不同的变体。

支持向量机

概述: Support Vector Machine是一种更为强大的分类算法,特别适合于线性不可分的数据集。SVM的目标是在特征空间中找到一个最优的决策边界(超平面),并且它有一个非常独特的特点:最大化分类边界的间隔

原理:

  • 超平面:SVM在特征空间中寻找一个超平面,将数据点分类。对于线性可分数据,超平面的方程是:
    w T x + b = 0 w^Tx+b=0 wTx+b=0

  • 最大化间隔:SVM不仅寻找一个可以分开数据的超平面,还要找到那个离两类数据点最远的超平面,确保间隔最大化。这被称为最大化分类边界的间隔(Margin)。这样可以增强模型的鲁棒性,减少过拟合

  • 支持向量:距离决策边界最近的那些数据点被称为支持向量。这些点对决策边界有最重要的影响

损失函数:

  • 合页损失函数(Hinge Loss):是用于分类函数的损失函数,用来惩罚错误分类分类边界附件的样本
    L ( y i , f ( x i ) ) = max ⁡ ( 0 , 1 − y i f ( x i ) ) L(y_i,f(x_i))=\max(0,1-y_if(x_i)) L(yi,f(xi))=max(0,1yif(xi))
    其中:

    • y i ∈ { − 1 , 1 } y_i \in \{-1, 1\} yi{1,1} 是样本i的真实标签(SVM通常处理二分类问题)
    • f ( x i ) = w ⊤ x i + b f(x_i)=\mathbf{w}^\top x_i+b f(xi)=wxi+b 是模型对样本 x i x_i xi的预测结果,表示超平面w和 x i x_i xi的内积再加上偏置b
    • 1 − y i f ( x i ) 1-y_if(x_i) 1yif(xi) 是衡量样本离决策边界的距离
    • 如果样本被正确分类且距离边界大于 1,损失为 0;否则,损失随着样本距离边界的接近或错误分类而增加。
  • 正则化项:SVM 模型的目标是找到能够最大化分类间距(margin)的超平面。因此,为了平衡分类误差和间距的最大化,损失函数通常还包括一个正则化项,用来控制模型的复杂度(即防止过拟合)。常见的正则化项是L2 正则化,其形式为:
    R ( w ) = 1 2 ∥ w ∥ 2 R(\mathbf{w})=\frac12\|\mathbf{w}\|^2 R(w)=21w2

  • 总损失函数:
    J ( w , b ) = 1 2 ∥ w ∥ 2 + C ∑ i = 1 m max ⁡ ( 0 , 1 − y i ( w ⊤ x i + b ) ) J(\mathbf{w},b)=\frac12\|\mathbf{w}\|^2+C\sum_{i=1}^m\max(0,1-y_i(\mathbf{w}^\top x_i+b)) J(w,b)=21w2+Ci=1mmax(0,1yi(wxi+b))

    • 其中:
      • C 是一个超参数,用于控制正则化项和合页损失之间的权衡。较大的 C 会减少分类错误,但可能导致过拟合;较小的 C 会增加容错性,防止过拟合。
      • m 是训练样本的数量。

梯度公式:

  • 对权重向量W的梯度:
    ∂ J ( w , b ) ∂ w = w − C ∑ i ∈ M y i x i \frac{\partial J(\mathbf{w},b)}{\partial\mathbf{w}}=\mathbf{w}-C\sum_{i\in\mathcal{M}}y_ix_i wJ(w,b)=wCiMyixi

  • 对偏置b的梯度:
    ∂ J ( w , b ) ∂ b = − C ∑ i ∈ M y i \frac{\partial J(\mathbf{w},b)}{\partial b}=-C\sum_{i\in\mathcal M}y_i bJ(w,b)=CiMyi

核技巧(kernel Trick): 当数据无法通过线性超平面分割时,SVM使用核技巧将数据映射到高维空间。常见的核函数包括:

  • 多项式核(Polynomial Kernel):将原始数据通过多项式映射到高维空间
  • 高斯核/径向基核(RBF Kernel):将数据点投影到无穷维空间,使得非线性数据在高维空间中变得线性可分

梯度下降

是一种用于优化线性回归和线性分类模型的迭代方法,通过计算损失函数的梯度,并沿着梯度的反方向迭代更新参数 W W W,逐步逼近最优解。梯度下降有三种主要变体:

  • 批量梯度下降(Batch Gradient Descent, BGD)
  • 随机梯度下降(Stochastic Gradient Descent, SGD)
  • 随机批量梯度下降(Mini-Batch Gradient Descent, MBGD)

样本数的对梯度公式的影响:

  • 如果不除以样本数,计算得到的梯度是累计的梯度,也就是每个样本的误差对权重的累积影响
  • 如果除以样本数,计算得到的是平均梯度,每次更新会使用平均误差对权重进行更新

参数更新公式:
W = W − η ∇ L ( W ) W=W-\eta\nabla L(W) W=WηL(W)
其中:

  • W W W是参数向量
  • η \eta η是学习率,控制更新的步长
  • 损失函数 J ( W ) J(W) J(W),在线性回归中有介绍,不同的模型可以选取不同的损失函数
  • ∇ L ( W ) \nabla L(W) L(W)是损失函数 J ( W ) J(W) J(W)对参数W的梯度(也就是求导,一般来说,如果损失函数是每个样本损失的平均值,也就是除以了样本数m,那损失的函数的梯度就不再需要除以样本数m了

批量梯度下降

概述: 在每次迭代中,使用整个训练集来计算梯度

过程:

  1. 首先,初始化模型参数
  2. 定义损失函数 J ( W ) J(W) J(W)
  3. 求解损失函数的导数,也就是损失函数的梯度函数 L ( W ) L(W) L(W)
  4. 根据上面给的参数更新公式更新W
  5. 重复迭代

优点:

  • 更新稳定,避免了由样本噪声引起的波动
  • 收敛到全局最优解,梯度方向更精确

缺点:

  • 计算成本高
  • 不适合大规模数据集

随机梯度下降

概述: 在每次迭代中,只使用一个样本计算梯度

SGD优缺点:

  • 优点:
    • 效率高:对于大规模数据集,SGD不需要每次都遍历整个数据集,它每次只对一个样本进行更新,使得计算更快
    • 内存友好:由于它只需要处理一个样本,内存消耗相对较低,适合处理大数据
    • 在线学习:SGD可以随着新数据的到来在线更新模型, 而不需要每次都从头开始训练
  • 缺点:
    • 噪声较大:由于每次更新使用的是单个样本,更新方向可能不是全局最优,因此SGD的收敛路径往往比较噪声且不稳定
    • 需要调整学习率:学习率 η \eta η的选择至关重要。如果学习率过大,参数更新可能会错过最优点;如果学习率过小,收敛速度将非常慢

学习率衰减策略:

为了解决噪声问题,常见的做法是在训练过程中逐渐减低学习率,这种方法可以在训练初期进行较大步长的更新,使得模型快速接近最优解,而在后期逐渐减小步长,使得模型在最优解附件收敛

学习率衰减公式的常见形式是:
KaTeX parse error: Expected 'EOF', got '_' at position 40: …1 + \text{decay_̲rate} \cdot t}
其中:

  • η 0 \eta_0 η0是初始学习率
  • t t t是当前的迭代次数
  • KaTeX parse error: Expected 'EOF', got '_' at position 12: \text{decay_̲rate}是学习率的衰减系数

伪代码:

初始化 W
for epoch in range(num_epochs):for i in range(m):  # m是样本数量随机选取一个样本 (x_i, y_i)计算该样本的梯度: grad = x_i * (x_i W - y_i)取负梯度方向更新参数: W = W - η * grad记录训练集或验证集的损失

批量随机梯度下降

概述: 每次使用一小批随机样本计算梯度

过程:

  1. 初始化权重W和b

  2. 选择批量大小B(比如32,64等

  3. 每次迭代时,从训练集中随机抽取一批样本,计算该批样本上的损失函数梯度,然后更新权重:
    W = W − η ∇ L ( W ) W=W-\eta\nabla L(W) W=WηL(W)
    其中 η \eta η是学习率, η ∇ L ( W ) \eta\nabla L(W) ηL(W)是对权重的梯度

优点:

  • 在每次更新时引入随机性,避免陷入局部最优解
  • 更新效率较高,能在大规模数据集上加速训练
  • 在批量计算中还可以利用并行化处理,进一步提高效率

http://www.ppmy.cn/embedded/118686.html

相关文章

前端面试总结2

1.计算圆的周长 class Shape{ constructor(radius10){ this.radiusradius; } diameter(){ return diameter2*this.radius; } perimeter(){ const circumference2*Math.PI*this.radius; console.log(this.diameter,circumference); return circumference; }} const cirlenew sh…

使用docker形式部署prometheus+alertmanager+钉钉告警

一、拉取所需要的镜像 docker pull prom/node-exporter docker pull grafana/grafana docker pull prom/prometheus docker pull prom/alertmanager 其中 prom/node-exporter:用于收集主机系统信息和指标的 grafana/grafana:是一个用于可视化和分…

梳理软件需求,期望不合理问题如何解决?

在梳理软件需求时,往往会遇到:客户会提出过高或不切实际的期望,如要求功能过于复杂或时间过于紧迫,这些期望可能不切实际,导致开发团队难以满足,从而影响项目进度、成本和质量。 面对这些期望不合理问题&a…

【vue3】登录功能怎么实现?

无论是手机端还是pc端,几乎都包含登录注册方面功能,今天总结登录注册功能。 实现功能 注册 密码加密 登录 校验 token处理 1.环境搭建运行(nodeexpressmongodb) 在目录里安装express和mongoose,并在根目录创建server.j…

用上这10条神指令(prompt),让ChatGPT快速写出优质高分科研论文

大家好,感谢关注。我是七哥,一个在高校里不务正业,折腾学术科研AI实操的学术人。关于使用ChatGPT等AI学术科研的相关问题可以和作者七哥(yida985)交流,多多交流,相互成就,共同进步,为大家带来最酷最有效的智能AI学术科研写作攻略。 这篇文章将给大家分享十个科研论文写…

JAVA:Spring Boot 集成 Tess4J 实现文字识别的技术指南

请关注微信公众号:拾荒的小海螺 博客地址:http://lsk-ww.cn/ 1、简述 图片文字识别(Optical Character Recognition, OCR)是一项将图像中的文字转换为可编辑文本的技术。Tess4J 是 Tesseract OCR 引擎的一个 Java 封装&#xff…

MySQL 主键索引等值查询加什么锁?

这一期介绍读已提交、可重复读两个隔离级别下,主键索引等值查询的加锁情况。 作者:操盛春,爱可生技术专家,公众号『一树一溪』作者,专注于研究 MySQL 和 OceanBase 源码。 爱可生开源社区出品,原创内容未经…

IM开发首选:WebSocket实现分频道广播的设计思路和实现难点分析

IM开发首选:WebSocket实现分频道广播的设计思路和实现难点分析 即时通讯(Instant Messaging,简称IM)应用在现代社会中已经无处不在。无论是个人聊天、群组讨论,还是企业内部通信,IM都发挥着至关重要的作用。…