线性回归(一)基于Scikit-Learn的简单线性回归

ops/2025/2/28 2:40:04/

主要参考学习资料:

机器学习算法的数学解析与Python实现》莫凡 著

前置知识:线性代数-Python

目录

  • 问题背景
  • 数学模型
    • 假设函数
    • 损失函数
    • 优化方法
    • 训练步骤
  • 代码实现
  • 特点

问题背景

回归问题是一类预测连续值的问题,满足这样要求的数学模型称作回归模型

线性方程指未知数都是一次的方程,其图像为一条直线。

线性回归问题的回归模型使用线性方程,适用于数据集点沿线性分布的场景。

数学模型

假设函数

假设函数是一个将输入映射到输出的函数,用于预测输出变量的值。

线性回归模型的假设函数:

H ( x ) = w T x i + b H(x)=\boldsymbol w^Tx_i+b H(x)=wTxi+b

其中 w \boldsymbol w w为模型参数/权重, x i x_i xi为模型输入,均为 n n n维向量。 b b b为偏置项。

损失函数

损失函数是一个体现预测值与真实值的偏差的函数。

线性回归模型的损失函数:

L ( x ) = ∥ y ^ − y ∥ 2 2 L(x)=\left\|\hat y-y\right\|^2_2 L(x)=y^y22

其中 y ^ \hat y y^为预测值, y y y为真实值。

符号 ∥ ∥ \left\|\right\| 范数正则化,简称范数。下标 n n n表示 L n \mathrm Ln Ln范数,即 n n n维欧几里得空间的距离,例如:

∥ x ∥ 1 = ∑ i = 1 n ∣ x i ∣ \left\|x\right\|_1=\displaystyle\sum^n_{i=1}|x_i| x1=i=1nxi

∥ x ∥ 2 = ∑ i = 1 n x i 2 \left\|x\right\|_2=\displaystyle\sqrt{\sum^n_{i=1}x_i^2} x2=i=1nxi2

优化方法

优化方法是以损失函数为依据将偏差减到最小的方法,通常使用梯度下降等现成算法,此处即通过调节参数 w \boldsymbol w w b b b使损失函数求得最小值:

min ⁡ w , b ∥ y ^ − y ∥ 2 2 \underset{\mathrm w,b}{\min}\left\|\hat y-y\right\|^2_2 w,bminy^y22

w \boldsymbol w w为例,其调节方法为 :

w 新 = w 旧 − 学习率 ∗ 损失值 \boldsymbol w_新=\boldsymbol w_旧-学习率*损失值 w=w学习率损失值

学习率是一个由外部输入用于控制训练过程的参数,称为超参数,影响每次偏差带来的参数调整幅度。

损失值可通过损失函数对 w \boldsymbol w w求偏导得出。

训练步骤

①为假设函数初始化参数 w \boldsymbol w w b b b

②将每个训练样本 x i x_i xi代入假设函数,最终计算损失值。

③利用优化方法调整假设函数的参数,重复以上步骤使得损失值最小。

代码实现

Scikit-Learn对各类机器学习算法进行了良好封装,可以调用简单的函数来实现模型训练,安装命令为:

pip install -U scikit-learn

基于Scikit-Learn库的线性回归算法

python">#从Scikit-Learn库导入线性模型
from sklearn import linear_model import matplotlib.pyplot as plt  
import numpy as np  #生成数据集,样本特征x为间隔均匀的序列,结果y由线性方程给出
x = np.linspace(-3, 3, 30)  
y = 2*x + 1  #将数据集从一维数组转换为二维数组以符合scikit-learn的输入要求
x1 = [[i] for i in x]  
y1 = [[i] for i in y]  #创建线性回归模型
model = linear_model.LinearRegression()  
#训练模型
model.fit(x1, y1)  #绘制拟合线条,predict方法返回模型对输入的预测值
plt.plot(x, model.predict(x1), color='red')  
#绘制原始数据点
plt.scatter(x, y)  
#显示图像
plt.show()

运行结果:

若要添加随机扰动,生成较不规则的数据集,可将代码对应部分替换为:

x = np.linspace(-3, 3, 30)  
y = 2*x + 1  
x = x+np.random.rand(30)  

运行结果:

可以通过 model.coef _ \texttt{model.coef}\_ model.coef_ model.intercept _ \texttt{model.intercept}\_ model.intercept_得到模型当前 w \boldsymbol w w b b b的值。

特点

优点:形式简单,可解释性强,容易理解和实现。

缺点:不能表达复杂的模式,对于非线性问题表现不佳。

应用领域:金融、气象预报等能够用线性关系进行描述的问题领域。


http://www.ppmy.cn/ops/161832.html

相关文章

HW面试经验分享 | 北京蓝中研判岗

目录: 所面试的公司介绍 面试官的问题: 1、面试官先就是很常态化的让我做了一个自我介绍 2、自我介绍不错,听你讲熟悉TOP10漏洞,可以讲下自己熟悉哪些方面吗? 3、sql注入原理可以讲下吗? 4、sql注入绕WAF有…

.NET Core MVC IHttpActionResult 设置Headers

最近碰到调用我的方法要求返回一个代码值,但是要求是不放在返回实体里,而是放在返回的Headers上 本来返回我是直接用 return Json(res) 这种封装的方法特别简单,但是没有发现设置headers的地方 查询过之后不得已换了个返回 //原来方式 //…

通过命令启动steam的游戏

1. 启动Steam客户端 在命令行输入以下命令来启动Steam客户端: start steam://open/main 如果Steam未安装在默认路径,可能需要先定位到Steam的安装目录,例如: cd C:\Program Files (x86)\Steam start steam://open/main 2. 通过…

机器人部分专业课

华东理工 人工智能与机器人导论 Introduction of Artificial Intelligence and Robots 必修 考查 0.5 8 8 0 1 16477012 程序设计基础 The Fundamentals of Programming 必修 考试 3 64 32 32 1 47450012 算法与数据结构 Algorithm and Data Structure 必修 考试 3 56 40 …

python爬虫系列课程4:一个例子学会使用xpath语法

python爬虫系列课程4:一个例子学会使用xpath语法 本文通过一个例子,学会xpath的各种语法,可以作为xpath的查询手册使用,代码如下: from lxml import etreetext = <div> <ul><li class="item-1"><a href="link1.html">first i…

Centos服务器GCC安装

写在前面 唠叨两句 GCC是Linux系统中&#xff0c;进行C/C程序开发及运行的常用的工具包&#xff0c;很多软件安装的时候&#xff0c;需要使用这些工具包来进行运行。 本文所述是在Centos7的环境下进行 正文 gcc、gvv验证 使用如下命令验证gcc&#xff0c;如出现如图参数则…

10道Redis常见面试题速通

引言 本系列聚焦频率最高的面试题&#xff0c;用最简洁的文字表达中心思想&#xff0c;速通面试 1、Redis持久化数据和缓存怎么做扩容&#xff1f; 如果Redis被当做缓存使用&#xff0c;使用一致性哈希实现动态扩容缩容。如果Redis被当做一个持久化存储使用&#xff0c;必须使…

PydanticToolsParser 工具(tool call)把 LLM 生成的文本转成结构化的数据(Pydantic 模型)过程中遇到的坑

PydanticToolsParser 的作用 PydanticToolsParser 是一个工具&#xff0c;主要作用是 把 LLM 生成的文本转成结构化的数据&#xff08;Pydantic 模型&#xff09;&#xff0c;让代码更容易使用这些数据进行自动化处理。 换句话说&#xff0c;AI 生成的文本通常是自然语言&…