从零深度学习:(2)最小二乘法

server/2025/1/18 9:11:55/

今天我们从比较简单的线性回归开始讲起,还是一样我们先导入包

import numpy as np
import torch
import matplotlib as mpl
import matplotlib.pyplot as plt
a = torch.arange(1,5).reshape(2,2).float()
a

我们利用刚刚导入的画图的包将这两个点画出来,将1和3先索引出来作为横坐标,2和4作为纵坐标传入给plot,'o'表示画的是点而不是线

#画出上面两个点
plt.plot(a[:,0],a[:,1],'o')

                 

现在我们希望找到一条直线去穿过拟合这两个点,也就是所谓的线性回归,不妨设方程如下:

y = ax+b

我们在初中就学过两个点能够带入两个方程进行求解,将a和b通过解方程的形式求解出来。除了这种矩阵求解以外,我们还可以转化为一个优化问题来进行求解。其中优化问题最关键的两个就是优化指标和优化目标函数。我们现在的任务是找到一条直线拟合这两个点,所以显然目标就是将这两个点横坐标带进方程解析式的预测值y和实际的y(2,4)之间的误差变小。

我们可以在markdown中渲染得到如下表格,右侧是预测值和真实值的差值:

为了让差距变小,一个很朴素的想法就是求和变的最小,但是由于这里有正有负,可能会出现正负抵消的情况,所以这里我们采用先平方再求和,也就是所谓的误差平方和SSE:

至此我们已经完成了优化问题的转化,我们现在的目标就是找到a和b为何值的时候这个差值函数最小,因此上面这个函数也叫做目标函数。

我们导入画图工具包将这个函数图像画出来:

from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
x = np.arange(-1, 3, 0.1)  # 增加步长,减少数据点数量
y = np.arange(-1, 3, 0.1)
a, b = np.meshgrid(x, y)
SSE = (2 - a - b)**2 + (4 - 3*a - b)**2fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.plot_surface(a, b, SSE, cmap='rainbow')
plt.show()

我们不难看出这是一个凸函数,而对于一个凸函数来说,最小值显然存在,所以根据这一点,我们可以给出求解凸函数最小值的一般方法,也就是最小二乘法。关于这个最小二乘法,我们在高数和概率论的学习中都有涉及,如果有不懂的宝子可以去补一下。当然凸函数优化方法还有很多,我们会在后续的学习中陆续提及。

所以这里就是对a和b分别求偏导并令其等于0,即可求解出来。

求解得出方程为:y = x +1,也就是说当(a,b)等于(1,1)的时候函数取得最小值。

在求解完之后我们也可以通过借助autograd模块来帮助我们验证导数是否为零。

autograd

我们可以在jupyter中输入如上代码,会发现如果你的张量requires_grad属性等于True,你每计算一步都会记录在grad_fn当中。例如这里的y是通过x乘法得到的,所以下面是Mul也就是乘法的缩写,同理z的Pow是power的缩写。

我们也可以通过.grad_fn来查看具体内容:

grad_fn 存储了当前张量的计算源和操作类型,用于梯度计算。具体来说,它指向一个与该张量相关的操作对象,操作对象是由上次计算生成的,这些对象的存在是为了在反向传播时提供梯度计算的方法,并且它们是由 PyTorch 自动生成并维护的。

同时,这是链式存储的一部分。在反向传播中,PyTorch 会按照计算图的反向顺序计算每个张量的梯度。这些 grad_fn 实际上是梯度计算的链条,记录了张量是如何从前一个操作得到的,并允许在反向传播时依赖于这些操作生成梯度。

所以根据这个回溯机制,我们可以画出输出张量是怎么一步一步得来的并画出张量计算图,如下:

PyTorch的计算图是动态计算图,会根据可微分张量的计算过程自动生成,并且伴随着新张量或运算的加入不断更新,这使得PyTorch的计算图更加灵活高效,并且更加易于构建,动态图也更加适用于面向对象编程。


http://www.ppmy.cn/server/159315.html

相关文章

数据结构-栈队列OJ题

文章目录 一、有效的括号二、用队列实现栈三、用栈实现队列四、设计循环队列 一、有效的括号 (链接:ValidParentheses) 这道题用栈这种数据结构解决最好,因为栈有后进先出的性质。简单分析一下这道题:所给字符串不是空的也就是一定至少存在一…

C# 线程基础之 线程同步

线程同步的手段很多 lock 是通过内存索引块 0 1 切换 进行互斥的实现 互斥量 信号量 事件消息 其实意思就是 一个 标记量 通过这个标记 来进行类似的互斥手段 具体方式的分析 代码在后 1.互斥量 Mutex 作用 非常类似lock 一个Mutex 名称来代替 lock的引用对象 2.信号量 Semaph…

ASP.Net Identity + IODC 解析ReturnUrl

Identity Ids4 配置 成认证服务 一、创建 Identity身份验证 项目 创建的项目结构中 没有 注册和登录的 控制器和视图 配置数据库地址 》》默认已经生成了Miagratin 直接update-database 二、在Identity项目 配置 IdentityServer4 Nuget 两个包 》》》配置Config 类 usin…

鸿蒙Flutter实战:16-无痛开发指南(适合新手)

本文讲述如何通过 Flutter 开发鸿蒙原生应用。整个过程结合往期文章、实战经验、流程优化,体验丝滑、无痛。 无痛搭建开发环境 为了减少疼痛,这里使用全局唯一的 Flutter 版本开发。高阶用法可以参看往期同系列文章。 硬件准备 一台 Mac,一部…

当设置dialog中有el-table时,并设置el-table区域的滚动,看到el-table中多了一条横线

问题:当设置dialog中有el-table时,并设置el-table区域的滚动,看到el-table中多了一条横线; 原因:el-table有一个before的伪元素作为表格的下边框下,初始的时候已设置,在滚动的时候并没有重新设置…

BGP边界网关协议(Border Gateway Protocol)概念、邻居建立

一、定义 主要用于交换AS之间的可达路由信息,构建AS域间的传播路径,防止路由环路的产生,并在AS级别应用一些路由策略。当前使用的版本是BGP-4。 二、环境 底层以OSPF进行igp互联互通,上层使用BGP协议。 三、基本原理 1、BGP是一…

unity学习16:unity里向量的计算,一些方法等

目录 1 unity里的向量: 2 向量加法 2.1 向量加法的几何意义 2.2向量加法的标量算法 3 向量减法 3.1 向量减法的几何意义 3.2 向量减法的标量算法 4 向量的标量乘法 5 向量之间的乘法要注意是左乘 还是右乘 5.1 注意区别 5.2 向量,矩阵&#x…

Unity3D BEPUphysicsint定点数3D物理引擎使用详解

前言 Unity3D作为一款强大的游戏开发引擎,提供了丰富的功能和工具,助力开发者轻松创建多样化的游戏。而在游戏开发中,物理引擎的作用不可忽视。BEPUphysicsint是一个基于Unity3D的开源3D物理引擎项目,它通过采用定点数计算来实现…