【ShuQiHere】从零开始实现线性回归:深入理解反向传播与梯度下降

embedded/2024/10/15 18:23:42/

【ShuQiHere】

线性回归是一种简单但强大的回归分析方法,主要用于预测连续值。它在许多领域都有广泛的应用,尤其是当我们需要根据已有数据来预测未来的趋势时,线性回归显得尤为重要。虽然它是机器学习中最基础的算法之一,但理解其原理对掌握更复杂的算法至关重要。本文将带你一步步从零开始实现线性回归,并深入探讨反向传播与梯度下降这两个核心算法,帮助你打下扎实的基础。

线性回归的数学基础

线性回归的目标是找到一个线性函数,该函数能够尽可能准确地预测目标变量 ( Y ) 的值。这个线性函数的形式如下:

[
Y = W X + b Y = WX + b Y=WX+b
]

在这个公式中:

  • ( W ) 是权重向量,表示每个输入特征对输出的影响程度。
  • ( X ) 是输入特征向量,即我们用来进行预测的输入数据。
  • ( b ) 是偏置项,它帮助模型更好地拟合数据,特别是在输入特征的值很小时。

简单来说,线性回归试图找到一条直线(或者在多维情况下的一个超平面),使得这条线尽可能接近数据点。

损失函数

为了衡量模型的预测值与实际值之间的差异,我们使用均方误差(Mean Squared Error, MSE)作为损失函数。均方误差计算的是预测值与真实值之间的平均平方差异:

[
MSE = 1 2 m ∑ i = 1 m ( Y ( i ) − Y ^ ( i ) ) 2 \text{MSE} = \frac{1}{2m} \sum_{i=1}^{m} (Y^{(i)} - \hat{Y}^{(i)})^2 MSE=2m1i=1m(Y(i)Y^(i))2
]

在这里:

  • ( m ) 是样本数量。
  • ( \hat{Y}^{(i)} ) 是第 ( i ) 个样本的预测值。
  • ( Y^{(i)} ) 是第 ( i ) 个样本的真实值。

损失函数的值越小,说明模型的预测值越接近真实值。我们希望通过调整模型参数来最小化这个损失。

反向传播(Backward Propagation)

反向传播是一种计算梯度的算法,用于指导模型如何更新参数以减小损失。在反向传播中,我们首先计算损失函数相对于模型参数(即 ( W ) 和 ( b ) )的导数(梯度)。这些梯度表示了损失函数相对于权重和偏置的变化速率,它们告诉我们如何调整这些参数才能更快地减小损失。

对于线性回归模型,损失函数对权重 ( W ) 和偏置 ( b ) 的梯度计算如下:

[
∂ Loss ∂ W = 1 m ∑ i = 1 m ( Y ( i ) − Y ^ ( i ) ) × X ( i ) \frac{\partial \text{Loss}}{\partial W} = \frac{1}{m} \sum_{i=1}^{m} (Y^{(i)} - \hat{Y}^{(i)}) \times X^{(i)} WLoss=m1i=1m(Y(i)Y^(i))×X(i)
]

[
∂ Loss ∂ b = 1 m ∑ i = 1 m ( Y ( i ) − Y ^ ( i ) ) \frac{\partial \text{Loss}}{\partial b} = \frac{1}{m} \sum_{i=1}^{m} (Y^{(i)} - \hat{Y}^{(i)}) bLoss=m1i=1m(Y(i)Y^(i))
]

这些公式的含义是:我们计算每个样本的预测误差,然后根据这些误差的方向和大小来更新模型的权重和偏置,以减少整体损失。

反向传播的实现

def backward_propagation(X, A, Y):m = X.shape[1]dw = (1/m) * np.dot(X, (A - Y).T)db = (1/m) * np.sum(A - Y)return dw, db

在这个实现中,dw 是权重的梯度,db 是偏置的梯度。我们通过这些梯度来指导参数的更新。

梯度下降(Gradient Descent)

梯度下降是一种迭代优化算法,用于更新模型的参数,使损失函数达到最小值。其核心思想是沿着损失函数下降最快的方向(即梯度的反方向)更新参数。

梯度下降的更新规则如下:

[
W : = W − α × ∂ Loss ∂ W W := W - \alpha \times \frac{\partial \text{Loss}}{\partial W} W:=Wα×WLoss
]

[
b : = b − α × ∂ Loss ∂ b b := b - \alpha \times \frac{\partial \text{Loss}}{\partial b} b:=bα×bLoss
]

在这里:

  • ( \alpha ) 是学习率,它决定了每次参数更新的步伐大小。如果学习率太大,模型可能会跳过最优解;如果学习率太小,收敛速度会很慢。

梯度下降的实现

def update_parameters(w, b, dw, db, learning_rate):w = w - learning_rate * dwb = b - learning_rate * dbreturn w, b

在这个实现中,我们使用计算得到的梯度 dwdb 来更新权重 w 和偏置 b ,以逐步减小损失。

线性回归的实现步骤

理解了反向传播和梯度下降之后,我们可以开始实现一个完整的线性回归模型。我们将按照以下步骤进行实现:

1. 初始化参数

首先,我们需要初始化模型的权重 ( W ) 和偏置 ( b )。在这个例子中,我们将权重初始化为零或小的随机值,偏置初始化为零。

def initialize_parameters(dim):w = np.zeros((dim, 1))b = 0return w, b

2. 前向传播

前向传播用于计算模型的输出(预测值)。这是通过将输入特征与权重相乘并加上偏置来实现的。

def forward_propagation(X, w, b):return np.dot(w.T, X) + b

3. 计算损失

我们使用均方误差来衡量预测值与真实值之间的差异。这一步骤非常关键,因为它告诉我们当前模型的性能如何。

def compute_loss(A, Y):m = Y.shape[1]loss = (1/(2*m)) * np.sum((A - Y) ** 2)return loss

4. 模型训练

在模型训练过程中,我们通过多次迭代,不断进行前向传播、计算损失、反向传播以及参数更新,最终得到一个能够准确预测的模型。

def train(X, Y, num_iterations, learning_rate):w, b = initialize_parameters(X.shape[0])for i in range(num_iterations):A = forward_propagation(X, w, b)loss = compute_loss(A, Y)dw, db = backward_propagation(X, A, Y)w, b = update_parameters(w, b, dw, db, learning_rate)if i % 100 == 0:print(f"Iteration {i}, Loss: {loss}")return w, b

在这个实现中,我们通过多次迭代来训练模型,并在每 100 次迭代时输出当前的损失值,以便跟踪模型的学习进度。

5. 模型评估

最后,我们可以使用均方误差来评估模型的性能,查看模型在测试集上的表现如何。

def evaluate(X, Y, w, b):A = forward_propagation(X, w, b)return compute_loss(A, Y)

结论

线性回归虽然简单,但它是机器学习中至关重要的基础模型。通过深入理解其实现过程中的反向传播和梯度下降,我们可以更好地理解机器学习的核心思想。这些知识不仅有助于掌握线性回归的实现,还为学习更复杂的机器学习模型打下了坚实的基础。希望本文的讲解能帮助你更好地理解线性回归,并激发你对机器学习更深层次的探索欲望。


http://www.ppmy.cn/embedded/101534.html

相关文章

vue中使用vue-video-player插件播放视频 以及 audio播放音频

一、使用vue-video-player插件播放视频 安装 npm install vue-video-player --save 在main.js中引用 //引入视频播放插件 // main.js import VueVideoPlayer from vue-video-player import video.js/dist/video-js.css import vue-video-player/src/custom-theme.cssVue.use(V…

CPU、MPU、MCU、SOC分别是什么?

CPU、MPU、MCU和SoC都是与微电子和计算机科学相关的术语,它们在功能定位、应用场景以及处理能力等方面有所区别。具体如下: CPU:CPU是中央处理单元的缩写,它通常指计算机内部负责执行程序指令的芯片。CPU是所有类型计算机&#x…

SNP亮相 2024 SAP高科技行业峰会:科技新引擎 智领新增长

8月15日,以“科技新引擎 智领新增长”为主题的2024思爱普中国峰会行业论坛——SAP高科技行业峰会在上海成功举办。SNP中国受邀参与本次峰会,并发表主题演讲《云时代企业ERP升级创新实践案例》。合作伙伴ABeam、微软及100多位高科技行业CIO共同出席了本次…

SFF1604-ASEMI无人机专用SFF1604

编辑:ll SFF1604-ASEMI无人机专用SFF1604 型号:SFF1604 品牌:ASEMI 封装:ITO-220AB 批号:最新 恢复时间:35ns 最大平均正向电流(IF):16A 最大循环峰值反向电压&a…

Python 爬虫框架

Python 中有许多强大且主流的爬虫框架,这些框架提供了更高级的功能,使得开发和维护爬虫变得更加容易。以下是一些常用的爬虫框架: 1. Scrapy - 简介: Scrapy 是 Python 最流行的爬虫框架之一,设计用于快速、高效地从网站中提取…

回答评论:使用流遍历文件 list

网友视频评论 回答评论: arraylist里包含了一个文件夹内部文件和子文件夹 怎么使用steam 可以遍历整个文件夹 最后生成的集合里是所有的文件路径,比如D:/test test文件夹里面有1.mp4, test2,test3的文件夹, test2和test3内部也嵌套了文件夹和…

开源程序实操:岩土工程渗流问题的有限单元法应用

有限单元法在岩土工程问题中应用非常广泛,很多商业软件如Plaxis/Abaqus/Comsol等都采用有限单元解法。尽管各类商业软件使用方便,但其使用对用户来说往往是一个“黑箱子”。相比而言,开源的有限元程序计算方法透明、计算过程可控,…

STM案例一:灯闪烁

一、使用元件 STlink,STM32F103C8T6 二、接线方法 STM32与STLINK的接线方法为: GND-->GND DCLK-->SWCLK DIO-->SWDIO 3.3-->3.3V 三、配置调试器 选择魔术棒按钮,单击Debug,选择ST-link Debug,选…