从零深度学习:(2)最小二乘法

ops/2025/1/17 22:29:46/

今天我们从比较简单的线性回归开始讲起,还是一样我们先导入包

import numpy as np
import torch
import matplotlib as mpl
import matplotlib.pyplot as plt
a = torch.arange(1,5).reshape(2,2).float()
a

我们利用刚刚导入的画图的包将这两个点画出来,将1和3先索引出来作为横坐标,2和4作为纵坐标传入给plot,'o'表示画的是点而不是线

#画出上面两个点
plt.plot(a[:,0],a[:,1],'o')

                 

现在我们希望找到一条直线去穿过拟合这两个点,也就是所谓的线性回归,不妨设方程如下:

y = ax+b

我们在初中就学过两个点能够带入两个方程进行求解,将a和b通过解方程的形式求解出来。除了这种矩阵求解以外,我们还可以转化为一个优化问题来进行求解。其中优化问题最关键的两个就是优化指标和优化目标函数。我们现在的任务是找到一条直线拟合这两个点,所以显然目标就是将这两个点横坐标带进方程解析式的预测值y和实际的y(2,4)之间的误差变小。

我们可以在markdown中渲染得到如下表格,右侧是预测值和真实值的差值:

为了让差距变小,一个很朴素的想法就是求和变的最小,但是由于这里有正有负,可能会出现正负抵消的情况,所以这里我们采用先平方再求和,也就是所谓的误差平方和SSE:

至此我们已经完成了优化问题的转化,我们现在的目标就是找到a和b为何值的时候这个差值函数最小,因此上面这个函数也叫做目标函数。

我们导入画图工具包将这个函数图像画出来:

from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
x = np.arange(-1, 3, 0.1)  # 增加步长,减少数据点数量
y = np.arange(-1, 3, 0.1)
a, b = np.meshgrid(x, y)
SSE = (2 - a - b)**2 + (4 - 3*a - b)**2fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.plot_surface(a, b, SSE, cmap='rainbow')
plt.show()

我们不难看出这是一个凸函数,而对于一个凸函数来说,最小值显然存在,所以根据这一点,我们可以给出求解凸函数最小值的一般方法,也就是最小二乘法。关于这个最小二乘法,我们在高数和概率论的学习中都有涉及,如果有不懂的宝子可以去补一下。当然凸函数优化方法还有很多,我们会在后续的学习中陆续提及。

所以这里就是对a和b分别求偏导并令其等于0,即可求解出来。

求解得出方程为:y = x +1,也就是说当(a,b)等于(1,1)的时候函数取得最小值。

在求解完之后我们也可以通过借助autograd模块来帮助我们验证导数是否为零。

autograd

我们可以在jupyter中输入如上代码,会发现如果你的张量requires_grad属性等于True,你每计算一步都会记录在grad_fn当中。例如这里的y是通过x乘法得到的,所以下面是Mul也就是乘法的缩写,同理z的Pow是power的缩写。

我们也可以通过.grad_fn来查看具体内容:

grad_fn 存储了当前张量的计算源和操作类型,用于梯度计算。具体来说,它指向一个与该张量相关的操作对象,操作对象是由上次计算生成的,这些对象的存在是为了在反向传播时提供梯度计算的方法,并且它们是由 PyTorch 自动生成并维护的。

同时,这是链式存储的一部分。在反向传播中,PyTorch 会按照计算图的反向顺序计算每个张量的梯度。这些 grad_fn 实际上是梯度计算的链条,记录了张量是如何从前一个操作得到的,并允许在反向传播时依赖于这些操作生成梯度。

所以根据这个回溯机制,我们可以画出输出张量是怎么一步一步得来的并画出张量计算图,如下:

PyTorch的计算图是动态计算图,会根据可微分张量的计算过程自动生成,并且伴随着新张量或运算的加入不断更新,这使得PyTorch的计算图更加灵活高效,并且更加易于构建,动态图也更加适用于面向对象编程。


http://www.ppmy.cn/ops/150935.html

相关文章

【云岚到家】-day02-客户管理-认证授权

第二章 客户管理 1.认证模块 1.1 需求分析 1.基础概念 一般情况有用户交互的项目都有认证授权功能,首先我们要搞清楚两个概念:认证和授权 认证: 就是校验用户的身份是否合法,常见的认证方式有账号密码登录、手机验证码登录等 授权:则是该用…

XML序列化和反序列化的学习

1、基本介绍 在工作中,经常为了调通上游接口,从而对请求第三方的参数进行XML序列化,这里常使用的方式就是使用JAVA扩展包中的相关注解和类来实现xml的序列化和反序列化。 2、自定义工具类 import javax.xml.bind.JAXBContext; import javax.x…

EF Core实体跟踪

快照更改跟踪 实体类没有实现属性值改变的通知机制,EF Core是如何检测到变化的呢? 快照更改跟踪:首次跟踪一个实体的时候,EF Core 会创建这个实体的快照。执行SaveChanges()等方法时,EF Core将会把存储的快照中的值与…

C# 多线程发展史(面试思路)

多线程技术 本身是为了 提高 cpu利用率 提高效率而生 因为 cpu分片机制 导致 多线程存在顺序与业务不符合情况 为了满足 正确的执行业务顺序 而诞生第一个要点线程同步 无论是控制主线程的同步等待 thread join task result task wait() 还是线程之间对于共享资源 同步的多种…

【MacOS】恢复打开系统设置的安全性的允许以下来源的应用程序的“任何来源”

在系统更新后,系统设置的安全性的允许以下来源的应用程序的“任何来源”可能会被修改为“来自APP开发者”。 操作步骤: So, I figured it out how to allow apps from anywhere. But learned its the order of operations on how to enable this optio…

STM32-Flash存储

目录 1.0 闪存模块组织 2.0 Flash基本结构 3.0 Flash解锁 4.0 指针访问存储器地址 5.0 程序存储器编程 6.0 选项字节 7.0 选项字节编程 8.0 选项字节擦除 9.0 电子签名 10.0 手册解读 定义: STM32F1系列的FLASH包含程序存储器、系统存储器和选项字节三个部…

Linux服务器网络丢包场景及解决办法

一、Linux网络丢包概述 在数字化浪潮席卷的当下,网络已然成为我们生活、工作与娱乐不可或缺的基础设施,如同空气般,无孔不入地渗透到各个角落。对于 Linux 系统的用户而言,网络丢包问题却宛如挥之不去的 “噩梦”,频繁…

智能物流升级利器——SAIL-RK3576核心板AI边缘计算网关设计方案(一)

近年来,随着物流行业智能化和自动化水平不断提升,数据的实时处理与智能决策成为推动物流运输、仓储管理和配送优化的重要手段。传统的集中式云平台虽然具备强大计算能力,但高延迟和带宽限制往往制约了物流现场的即时响应。为此,我…