机器学习(1)——线性回归、线性分类与梯度下降

news/2024/12/21 5:45:00/

文章目录

  • 线性回归
  • 线性分类
    • 线性可分数据
    • 线性不可分数据
    • 逻辑回归
    • 支持向量机
  • 梯度下降
    • 批量梯度下降
    • 随机梯度下降
    • 批量随机梯度下降

线性回归

概述:

在一元线性回归中,我们假设目标变量y与特征变量x存在线性关系,模型表达式为:
y = W 0 + W 1 x 1 + W 2 x 2 + ⋯ + W n x n + ϵ y=W_0+W_1x_1+W_2x_2+\cdots+W_nx_n+\epsilon y=W0+W1x1+W2x2++Wnxn+ϵ
其中:

  • W 0 W_0 W0是截距(bias)
  • W 1 , W 2 , . . . , W n W_1,W_2,...,W_n W1,W2,...,Wn是回归系数(权重)
  • ε是噪声或误差项

对于多个样本的情况,可以将特征表示为矩阵X,目标值为向量y。在这种情况下,线性回归模型可以写作:
y ^ = X W \hat{y}=XW y^=XW
其中:

  • X是输入特征矩阵(包含所有样本)
  • W是权重向量
  • y ^ \hat{y} y^是模型的预测值向量

损失函数(MSE):

为了让模型拟合数据,通常我们会使用均方误差(MSE)作为损失函数,来度量模型的预测值 y ^ \hat{y} y^与实际值y之间的差异:
J ( W ) = 1 2 m ∑ i = 1 m ( y ^ i − y i ) 2 = 1 2 m ( X W − y ) T ( X W − y ) J(W)=\frac1{2m}\sum_{i=1}^m(\hat{y}_i-y_i)^2=\frac1{2m}(XW-y)^T(XW-y) J(W)=2m1i=1m(y^iyi)2=2m1(XWy)T(XWy)
其中m是样本数

线性回归的闭式解: 指通过直接求解方程组,得到回归模型的参数(即权重向量)的解析解,而不需要通过迭代的优化算法(如梯度下降)来找到最优解

闭式解的推导:

最小化损失函数 J ( W ) J(W) J(W)以找到最优权重W。通过对 W W W求导并让导数为0,可以得到线性回归的解析解。首先对损失函数求导:
∂ J ( W ) ∂ W = 1 m X T ( X W − y ) \frac{\partial J(W)}{\partial W}=\frac1mX^T(XW-y) WJ(W)=m1XT(XWy)
将导数设置为0,求解 W W W
X T ( X W − y ) = 0 X T X W = X T y \begin{gathered} X^T(XW-y)=0 \\ X^TXW=X^Ty \end{gathered} XT(XWy)=0XTXW=XTy
可以通过矩阵求逆的方式得到 W W W
W = ( X T X ) − 1 X T y W=(X^TX)^{-1}X^Ty W=(XTX)1XTy
这就是线性回归的闭式解公式

闭式解的核心思想:

  • 直接求解:通过解析方法一次性求出最优权重 W W W,不需要像梯度下降一样逐步优化
  • 线性代数运行:通过矩阵转置、乘法和求逆等线性代数运算实现
  • 适用场景:对于小规模数据集、闭式解可以快速得到结果。然而当数据量非常大时,计算 X T X X^TX XTX的逆矩阵可能非常耗时,因此在大数据集上通常采用梯度下降等数据优化方法

优点与缺点:

  • 优点:直接得到最优解,计算速度快(适合小数据集
  • 缺点:对于高维度数据集,矩阵求逆的计算复杂度较高 O ( n 3 ) O(n^3) O(n3),在数据量过大时不适用

线性分类

概述: 线性分类器是基于线性决策边界进行分类的模型,形式上它会学到一个权重向量W和一个偏置b,其决策规则可以表示为:
f ( x ) = w T x + b f(\mathbf{x})=\mathbf{w}^T\mathbf{x}+b f(x)=wTx+b
在这种情况下,分类是根据f(X)的符号来进行的:

  • 如果f(x) > 0,则将数据点分类为正类
  • 否则为负类

这种方式只输出一个硬分类的结果,没有给出分类的概率

线性可分数据

概述: 指的是数据集中的不同类别可以通过一条直线(在二维空间中)或一个超平面(在高维空间中)完全分开,没有任何重叠或错误分类

特点:

  • 可以找到一个线性决策边界(如一条直线或一个超平面),使得数据集中所有点都可以准确分到正确的类别
  • 这种类比的数据适合使用线性分类器,如感知器、线性支持向量机(SVM)等

二维平面中的例子:

类别 A: (蓝色点)   类别 B: (红色点)蓝   蓝   蓝   蓝(直线)
红   红   红   红

在这种情况下,直线可以完全分开这两类点,没有任何交错

线性不可分数据

概述: 是指数据集中不同类别的点不能通过一条直线(或超平面)来完全分开,一些数据会落在错误的边界一侧,导致无法完美分类

特点:

  • 没有单一的线性边界可以准确分隔数据类别
  • 线性分类器在这种情况下表现不佳,因为它们依赖于线性分界线
  • 处理线性不可分数据的常用方法包括使用非线性模型(如核化支持向量机、决策树)或者对特征进行转换,,使数据在更高维空间中线性可分

二维平面中的例子:

类别 A: (蓝色点)   类别 B: (红色点)蓝   红   蓝   红
红   蓝   红   蓝

在这种情况下,无论如何放置一条直线,总会有一部分点被错误分类

解决线性不可分问题:

  • 引入非线性分类器:如使用核支持向量机(SVM),将数据映射到高维空间,使其在高维空间中线性可分。
  • 增加特征:通过添加多项式特征或交互特征,可以在输入空间中创建一个更复杂的模型。
  • 使用核技巧 (Kernel Trick):这是 SVM 的一个重要特性,通过核函数将低维数据映射到高维空间,使原本线性不可分的数据在高维空间中变得线性可分。

逻辑回归

概述: 是一种用于二分类问题的线性模型,尽管名字里有回归,它实际上用于分类任务。可以说是线性分类的一种特例,但它采用了概率的方式进行分类决策。

核心思想:

逻辑回归的目标是通过学习到的模型预测某个输入属于某个类别的概率。其基本形式是将线性回归的输出通过Sigmoid函数转换为一个介于0到1之间的概率值

  • 线性部分:给定一个输入向量X和模型参数W,线性部分的输出为:
    z = W T X + b z=W^TX+b z=WTX+b
    这个W是权重向量,b是偏置

  • Sigmoid函数:将线性输出z转换为概率值p:
    p = 1 1 + e − z p=\frac1{1+e^{-z}} p=1+ez1
    这个p表示预测结果为正类的概率

  • 损失函数:逻辑回归的损失函数通常是交叉熵损失,用于评估模型预测的概率分布和实际标签之间的差异:
    J ( W ) = − 1 m ∑ i = 1 m [ y i log ⁡ ( y ^ i ) + ( 1 − y i ) log ⁡ ( 1 − y ^ i ) ] J(W)=-\frac1m\sum_{i=1}^m\left[y_i\log(\hat{y}_i)+(1-y_i)\log(1-\hat{y}_i)\right] J(W)=m1i=1m[yilog(y^i)+(1yi)log(1y^i)]
    其中, y ^ i \hat{y}_i y^i是第i个样本的预测概率, y i y_i yi是实际标签,m是样本数

  • Logistic回归的梯度公式:
    ∇ L ( W ) = X T ( sigmoid ( X W ) − y ) \nabla L(W)=X^T(\text{sigmoid}(XW)-y) L(W)=XT(sigmoid(XW)y)

训练逻辑回归的过程: 通过最小化损失函数的值,调整权重W和偏置b。最常用的优化算法是梯度下降,其中包括不同的变体。

支持向量机

概述: Support Vector Machine是一种更为强大的分类算法,特别适合于线性不可分的数据集。SVM的目标是在特征空间中找到一个最优的决策边界(超平面),并且它有一个非常独特的特点:最大化分类边界的间隔

原理:

  • 超平面:SVM在特征空间中寻找一个超平面,将数据点分类。对于线性可分数据,超平面的方程是:
    w T x + b = 0 w^Tx+b=0 wTx+b=0

  • 最大化间隔:SVM不仅寻找一个可以分开数据的超平面,还要找到那个离两类数据点最远的超平面,确保间隔最大化。这被称为最大化分类边界的间隔(Margin)。这样可以增强模型的鲁棒性,减少过拟合

  • 支持向量:距离决策边界最近的那些数据点被称为支持向量。这些点对决策边界有最重要的影响

损失函数:

  • 合页损失函数(Hinge Loss):是用于分类函数的损失函数,用来惩罚错误分类分类边界附件的样本
    L ( y i , f ( x i ) ) = max ⁡ ( 0 , 1 − y i f ( x i ) ) L(y_i,f(x_i))=\max(0,1-y_if(x_i)) L(yi,f(xi))=max(0,1yif(xi))
    其中:

    • y i ∈ { − 1 , 1 } y_i \in \{-1, 1\} yi{1,1} 是样本i的真实标签(SVM通常处理二分类问题)
    • f ( x i ) = w ⊤ x i + b f(x_i)=\mathbf{w}^\top x_i+b f(xi)=wxi+b 是模型对样本 x i x_i xi的预测结果,表示超平面w和 x i x_i xi的内积再加上偏置b
    • 1 − y i f ( x i ) 1-y_if(x_i) 1yif(xi) 是衡量样本离决策边界的距离
    • 如果样本被正确分类且距离边界大于 1,损失为 0;否则,损失随着样本距离边界的接近或错误分类而增加。
  • 正则化项:SVM 模型的目标是找到能够最大化分类间距(margin)的超平面。因此,为了平衡分类误差和间距的最大化,损失函数通常还包括一个正则化项,用来控制模型的复杂度(即防止过拟合)。常见的正则化项是L2 正则化,其形式为:
    R ( w ) = 1 2 ∥ w ∥ 2 R(\mathbf{w})=\frac12\|\mathbf{w}\|^2 R(w)=21w2

  • 总损失函数:
    J ( w , b ) = 1 2 ∥ w ∥ 2 + C ∑ i = 1 m max ⁡ ( 0 , 1 − y i ( w ⊤ x i + b ) ) J(\mathbf{w},b)=\frac12\|\mathbf{w}\|^2+C\sum_{i=1}^m\max(0,1-y_i(\mathbf{w}^\top x_i+b)) J(w,b)=21w2+Ci=1mmax(0,1yi(wxi+b))

    • 其中:
      • C 是一个超参数,用于控制正则化项和合页损失之间的权衡。较大的 C 会减少分类错误,但可能导致过拟合;较小的 C 会增加容错性,防止过拟合。
      • m 是训练样本的数量。

梯度公式:

  • 对权重向量W的梯度:
    ∂ J ( w , b ) ∂ w = w − C ∑ i ∈ M y i x i \frac{\partial J(\mathbf{w},b)}{\partial\mathbf{w}}=\mathbf{w}-C\sum_{i\in\mathcal{M}}y_ix_i wJ(w,b)=wCiMyixi

  • 对偏置b的梯度:
    ∂ J ( w , b ) ∂ b = − C ∑ i ∈ M y i \frac{\partial J(\mathbf{w},b)}{\partial b}=-C\sum_{i\in\mathcal M}y_i bJ(w,b)=CiMyi

核技巧(kernel Trick): 当数据无法通过线性超平面分割时,SVM使用核技巧将数据映射到高维空间。常见的核函数包括:

  • 多项式核(Polynomial Kernel):将原始数据通过多项式映射到高维空间
  • 高斯核/径向基核(RBF Kernel):将数据点投影到无穷维空间,使得非线性数据在高维空间中变得线性可分

梯度下降

是一种用于优化线性回归和线性分类模型的迭代方法,通过计算损失函数的梯度,并沿着梯度的反方向迭代更新参数 W W W,逐步逼近最优解。梯度下降有三种主要变体:

  • 批量梯度下降(Batch Gradient Descent, BGD)
  • 随机梯度下降(Stochastic Gradient Descent, SGD)
  • 随机批量梯度下降(Mini-Batch Gradient Descent, MBGD)

样本数的对梯度公式的影响:

  • 如果不除以样本数,计算得到的梯度是累计的梯度,也就是每个样本的误差对权重的累积影响
  • 如果除以样本数,计算得到的是平均梯度,每次更新会使用平均误差对权重进行更新

参数更新公式:
W = W − η ∇ L ( W ) W=W-\eta\nabla L(W) W=WηL(W)
其中:

  • W W W是参数向量
  • η \eta η是学习率,控制更新的步长
  • 损失函数 J ( W ) J(W) J(W),在线性回归中有介绍,不同的模型可以选取不同的损失函数
  • ∇ L ( W ) \nabla L(W) L(W)是损失函数 J ( W ) J(W) J(W)对参数W的梯度(也就是求导,一般来说,如果损失函数是每个样本损失的平均值,也就是除以了样本数m,那损失的函数的梯度就不再需要除以样本数m了

批量梯度下降

概述: 在每次迭代中,使用整个训练集来计算梯度

过程:

  1. 首先,初始化模型参数
  2. 定义损失函数 J ( W ) J(W) J(W)
  3. 求解损失函数的导数,也就是损失函数的梯度函数 L ( W ) L(W) L(W)
  4. 根据上面给的参数更新公式更新W
  5. 重复迭代

优点:

  • 更新稳定,避免了由样本噪声引起的波动
  • 收敛到全局最优解,梯度方向更精确

缺点:

  • 计算成本高
  • 不适合大规模数据集

随机梯度下降

概述: 在每次迭代中,只使用一个样本计算梯度

SGD优缺点:

  • 优点:
    • 效率高:对于大规模数据集,SGD不需要每次都遍历整个数据集,它每次只对一个样本进行更新,使得计算更快
    • 内存友好:由于它只需要处理一个样本,内存消耗相对较低,适合处理大数据
    • 在线学习:SGD可以随着新数据的到来在线更新模型, 而不需要每次都从头开始训练
  • 缺点:
    • 噪声较大:由于每次更新使用的是单个样本,更新方向可能不是全局最优,因此SGD的收敛路径往往比较噪声且不稳定
    • 需要调整学习率:学习率 η \eta η的选择至关重要。如果学习率过大,参数更新可能会错过最优点;如果学习率过小,收敛速度将非常慢

学习率衰减策略:

为了解决噪声问题,常见的做法是在训练过程中逐渐减低学习率,这种方法可以在训练初期进行较大步长的更新,使得模型快速接近最优解,而在后期逐渐减小步长,使得模型在最优解附件收敛

学习率衰减公式的常见形式是:
KaTeX parse error: Expected 'EOF', got '_' at position 40: …1 + \text{decay_̲rate} \cdot t}
其中:

  • η 0 \eta_0 η0是初始学习率
  • t t t是当前的迭代次数
  • KaTeX parse error: Expected 'EOF', got '_' at position 12: \text{decay_̲rate}是学习率的衰减系数

伪代码:

初始化 W
for epoch in range(num_epochs):for i in range(m):  # m是样本数量随机选取一个样本 (x_i, y_i)计算该样本的梯度: grad = x_i * (x_i W - y_i)取负梯度方向更新参数: W = W - η * grad记录训练集或验证集的损失

批量随机梯度下降

概述: 每次使用一小批随机样本计算梯度

过程:

  1. 初始化权重W和b

  2. 选择批量大小B(比如32,64等

  3. 每次迭代时,从训练集中随机抽取一批样本,计算该批样本上的损失函数梯度,然后更新权重:
    W = W − η ∇ L ( W ) W=W-\eta\nabla L(W) W=WηL(W)
    其中 η \eta η是学习率, η ∇ L ( W ) \eta\nabla L(W) ηL(W)是对权重的梯度

优点:

  • 在每次更新时引入随机性,避免陷入局部最优解
  • 更新效率较高,能在大规模数据集上加速训练
  • 在批量计算中还可以利用并行化处理,进一步提高效率

http://www.ppmy.cn/news/1533089.html

相关文章

【星海saul随笔】Ubuntu基础知识

网卡 /etc/netplan/00-installer-config.yaml network:ethernets:ens33:addresses:- 172.16.x.x/24nameservers:addresses:- 223.5.5.5routes:- to: 0.0.0.0/0via: 172.16.x.254version: 2SSH 远程登录其他主机不进行确认 ##SSH 远程登录其他主机不进行确认 sed -i s/# Stric…

SpringCloud 2023 LoadBalancer介绍、使用、获取服务列表原理、负载均衡算法

目录 1. 介绍2. 使用3 获取服务列表原理4. 负载均衡算法 1. 介绍 功能: 提供客户端的负载均衡算法,将请求均摊到多个服务器上。属于客户端负载均衡(Nginx属于服务端负载均衡),会将服务列表缓存到JVM本地,然后客户端自己选择请求服务器支持S…

Python知识点:如何使用Spark与PySpark进行分布式数据处理

开篇,先说一个好消息,截止到2025年1月1日前,翻到文末找到我,赠送定制版的开题报告和任务书,先到先得!过期不候! Apache Spark 是一个强大的分布式数据处理系统,而 PySpark 是 Spark …

VMware搭建DVWA靶场

目录 1.安装phpstudy 2.搭建DVWA 本次搭建基于VMware16的win7系统 1.安装phpstudy 下载windows版本:小皮面板-好用、安全、稳定的Linux服务器面板! 安装后先开启mysql再开启apache,遇到mysql启动不了的情况,最后重装了phpstud…

第十四章:html和css做一个心在跳动,为你而动的表白动画

💖 让心跳加速,传递爱意 💖 在这个特别的时刻,让爱在跳动中绽放!🌟 无论是初次相遇的心动,还是陪伴多年的默契,我们的心总在为彼此跳动。就像这颗炙热的爱心,随着每一次的跳动,传递着满满的温暖与期待。 在这个浪漫的季节,让我们一同感受爱的律动!无论你是在…

秋招内推--招联金融2025

【投递方式】 直接扫下方二维码,或点击内推官网https://wecruit.hotjob.cn/SU61025e262f9d247b98e0a2c2/mc/position/campus,使用内推码 igcefb 投递) 【招聘岗位】 后台开发 前端开发 数据开发 数据运营 算法开发 技术运维 软件测试 产品策…

Midjourney中文版:解锁AI艺术创作的无限潜能

在数字化时代,艺术创作与科技的融合正以前所未有的速度推进,而Midjourney中文版正是这一趋势下的璀璨明星。作为一款专为中文用户设计的AI绘图工具,它不仅集成了最先进的深度学习技术,还通过本地化优化,为国内设计师和…

实战笔记:Vue2项目Webpack 3升级到Webpack 4的实操指南

在Web开发领域,保持技术的更新是非常重要的。随着前端构建工具的快速发展,Webpack已经更新到5.x版本,如果你正在使用Vue2项目,并且还在使用Webpack 3,那么是时候考虑升级一下Webpack了。我最近将我的Vue2项目从Webpack…