pytorch直线拟合

news/2024/11/30 5:38:01/

目录

1、数据分析

2、pytorch直线拟合


1、数据分析

直线拟合的前提条件通常包括以下几点:

存在线性关系:这是进行直线拟合的基础,数据点之间应该存在一种线性关系,即数据的分布可以用直线来近似描述。这种线性关系可以是数据点在直角坐标系上的分布趋势,也可以是通过实验或观测得到的数据点之间的关系。

数据点之间的误差是随机的:误差应该是随机的,没有任何系统性的偏差,并且符合随机误差的统计规律。这意味着数据点在拟合直线周围的分布应该是随机的,而不是受到某种特定的规律或趋势的影响。

直线应符合数据点的总体趋势:在拟合直线时,应该尽可能地符合数据点的总体趋势,而不是被一些异常值所影响。如果存在一些异常值,它们不应该对拟合结果产生过大的影响。

数据点的数量应该足够多:在进行直线拟合时,需要足够多的数据点来保证拟合结果的准确性和可靠性。通常来说,数据点的数量应该足够多,以便涵盖各种情况,并且能够反映出数据的真实分布情况。

数据的观测或实验过程是可靠的:数据的观测或实验过程应该是可靠的,这意味着数据的测量值应该是准确的,并且没有受到某些特定因素的影响。如果数据的观测或实验过程存在偏差或误差,那么直线拟合的结果也可能受到影响。

从散点图看出,数据具有明显的线性关系​,本例不过多讨论数据是满足直线拟合的其它条件。

import torch
import matplotlib.pyplot as plt
x=torch.Tensor([1.4,5,11,16,21])
y=torch.Tensor([14.4,29.6,62,85,113.4])
plt.scatter(x.numpy(),y.numpy())
plt.show()

2、pytorch直线拟合

基于梯度下降法实现直线拟合。训练过程实际上是一种批量梯度下降(Batch Gradient Descent),这是因为每次更新参数时都使用了所有的数据。另外,学习率 learning_rate 和训练轮数 epochs 是可以调整的超参数,对模型的训练效果有很大影响。

import torch
import matplotlib.pyplot as plt
def Produce_X(x):x0=torch.ones(x.numpy().size)X=torch.stack((x,x0),dim=1)return X
def train(epochs=1,learning_rate=0.01):for epoch in range(epochs):output=inputs.mv(w)loss=(output-target).pow(2).sum()loss.backward()w.data-=learning_rate*w.gradw.grad.zero_()if epoch%80==0:draw(output,loss)return w,loss
def draw(output,loss):plt.cla()plt.scatter(x.numpy(), y.numpy())plt.plot(x.numpy(),output.data.numpy(),'r-',lw=5)plt.text(5,20,'loss=%s' % (loss.item()),fontdict={'size':20,'color':'red'})plt.pause(0.005)
​
if __name__ == "__main__":x = torch.Tensor([1.4, 5, 11, 16, 21])y = torch.Tensor([14.4, 29.6, 62, 85.5, 113.4])X = Produce_X(x)inputs = Xtarget = yw = torch.rand(2, requires_grad=True)w,loss=train(10000,learning_rate=1e-4)print("final loss:",loss.item())print("weigths:",w.data)plt.show()
​

final loss: 8.216197967529297

weigths: tensor([5.0817, 5.6201])


http://www.ppmy.cn/news/1201161.html

相关文章

Go的错误处理

什么是错误? 错误表示程序中发生的任何异常情况。假设我们正在尝试打开一个文件,但该文件在文件系统中不存在。这是一种异常情况,表示为错误。 Go 中的错误是普通的旧值。就像任何其他内置类型(例如 int、float64 等&#xff09…

仿mudou库one thread one loop式并发服务器

目录 1.实现目标 2.HTTP服务器 实现高性能服务器-Reactor模型 模块划分 SERVER模块: HTTP协议模块: 3.项目中的子功能 秒级定时任务实现 时间轮实现 正则库的简单使用 通⽤类型any类型的实现 4.SERVER服务器实现 日志宏的封装 缓冲区Buffer…

git使用全解析 | git的原理 配置 基础使用 分支 合并

文章目录 1 git初步了解1.1 git的安装1.2 git原理模型1.3 git基础配置1.4 git基础用法1 将文件加入暂存区2 查看当前的git仓库状态3 删除文件4 commit 将暂存区文件加入本地git版本仓库5 查看提交历史 更改 2 分支2.1 创建分支2.2 查看分支2.3 切换分支2.4 内容比较 3 合并 本文…

python opencv 实现对二值化后的某一像素值做修改和mask叠加

实现对二值化后的某一像素值做修改 使用OpenCV的findNonZero函数找到所有非零(也就是像素值为255)的像素,然后遍历这些像素并修改他们的值。示例代码: import cv2 import numpy as np # 加载并二值化图像 img cv2.imread(…

大厂真题:【模拟】阿里蚂蚁2023秋招-讨厌鬼的区间

题目描述与示例 题目描述 讨厌鬼有一个数x,他每次操作可以令x x 1或x x - 1 讨厌鬼还有两个区间[l1, r1]和[l2, r2],讨厌鬼想知道,令x同时满足以下条件的最小操作数是多少? l1 ≤ x ≤ r1,且x是2的倍数l2 ≤ x ≤ r2&#xf…

Pandas数据分析Pandas进阶在线闯关_头歌实践教学平台

Pandas数据分析进阶 第1关 Pandas 分组聚合第2关 Pandas 创建透视表和交叉表 第1关 Pandas 分组聚合 任务描述 本关任务:使用 Pandas 加载 drinks.csv 文件中的数据,根据数据信息求每个大洲红酒消耗量的最大值与最小值的差以及啤酒消耗量的和。 编程要求…

Linux文件系统目录结构

典型的Linux文件系统目录结构的列表 典型的Linux文件系统目录结构的列表。每个目录都有其特定的用途: /bin: 存放系统引导和修复所需的二进制可执行文件,如ls,cp,mv等命令。 /boot: 存放操作系统引导文件,例如内核和…

基于SSM的出租车管理系统

基于SSM的出租车管理系统的设计与实现~ 开发语言:Java数据库:MySQL技术:SpringSpringMVCMyBatis工具:IDEA/Ecilpse、Navicat、Maven 系统展示 登录界面 管理员界面 驾驶员界面 摘要 基于SSM(Spring、Spring MVC、My…