使用线性回归模型逼近目标模型 | PyTorch 深度学习实战

news/2025/2/12 14:43:52/

前一篇文章,计算图 Compute Graph 和自动求导 Autograd | PyTorch 深度学习实战

本系列文章 GitHub Repo: https://github.com/hailiang-wang/pytorch-get-started

使用线性回归模型逼近目标模型

什么是回归

在统计学中,回归分析(regression analysis)指的是确定两种或两种以上变量间相互依赖的定量关系的一种统计分析方法。

简单说,就是使用统计学手段,分析变量之间的规律。发现规律后,可以根据给定的数据猜测特征空间的因变量1的数据。

在这里插入图片描述
参考文章:https://zhuanlan.zhihu.com/p/669597409

什么是线性回归

用一条直线去逼近数据的分布,参考定义:

A linear regression is a straight line that describes how the values of a response variable y y y change as the predictor variable x x x changes.

线性回归在实际中,可以包含多元的情况,比如:

z = w 1 x + w 2 y z = w_1 x + w_2 y z=w1x+w2y

更多线性回归介绍,参考文章。

使用 PyTorch 实现线性回归模型

实现量化投资:现在假如我们观测到了某支股票的数据 v v v,并且这支股票和石油的价格 x x x、黄金的价格 y y y和原煤的价格 z z z 有关联。因此,我们取得了不同时刻的 x x x y y y z z z 和对应的股票价格 v v v,现在,依据这些数据,建立一个方程式:

v = a x + b y + c z + d v = ax + by + cz + d v=ax+by+cz+d

此时,依赖历史采集的数据,我们来求 a,b,c,d 的值。
使用 PyTorch,这个程序实现如下。

代码

import torch
import matplotlib.pyplot as plt
import numpy as np# X and Y data,观测数据,包含多条
# 每条包含 3 个数据,分别代表石油、黄金、原煤的价格
x_data = [[65., 80., 75.],[89., 88., 93.],[80., 91., 90.],[30., 98., 100.],[50., 66., 70.]]
# 对应的这支股票的价格
y_data = [[152.],[185.],[189.],[196.],[142.]]# 定义输入 tensor 和输出 tensor 的变量
x=torch.autograd.Variable(torch.Tensor(x_data)) 
y=torch.autograd.Variable(torch.Tensor(y_data))# Our hypothesis XW+b,定义模型及参数
model=torch.nn.Linear(3,1,bias=True)# cost criterion,定义损失函数
criterion=torch.nn.MSELoss()# Minimize,优化器
optimizer=torch.optim.SGD(model.parameters(),lr=1e-7)# 训练轮数
epochs=200
cost_h=np.zeros(epochs)# Train the model,对于这个简单的问题,没有使用 SGD,每次都是将数据录入
for step in range(epochs):optimizer.zero_grad()hypothesis=model(x) # Our hypothesiscost=criterion(hypothesis,y)cost.backward()optimizer.step()cost_h[step]=cost.data.numpy()print(step,'Loss:',cost.data.numpy(),'\nPredict:\n',hypothesis.data.numpy())for name, param in model.named_parameters():if param.requires_grad:print(name, param.data)plt.plot(cost_h)
plt.show()

执行结果

使用 Python 运行上述程序,结果如下:

weight tensor([[-0.0980,  0.5064,  0.4115]])
bias tensor([-0.1257])

在这里插入图片描述

因为模型的定义是:

model=torch.nn.Linear(3,1,bias=True)

也就是包含了三个参数和一个偏置,最终机器学习得到的公式就是:
v = − 0.0980 x + 0.5064 y + 0.4115 z − 0.1257 v = -0.0980x + 0.5064y + 0.4115z -0.1257 v=0.0980x+0.5064y+0.4115z0.1257

我们就可以由某一天的黄金、石油、原煤的价格,来预测这支股票的价格。


  1. 因变量(dependent variable)函数中的专业名词,也叫函数值。函数关系式中,某些特定的数会随另一个(或另几个)会变动的数的变动而变动,就称为因变量。如:Y=f(X)。此式表示为:Y随X的变化而变化。Y是因变量,X是自变量。 ↩︎


http://www.ppmy.cn/news/1570715.html

相关文章

[手机Linux] onepluse6T 系统重新分区

一,刷入TWRP 1. 电脑下载 Fastboot 工具(解压备用)和对应机型 TWRP(.img 后缀文件,将其放入前面解压的文件夹里) 或者直接这里下载:TWRP 2. 将手机关机,长按音量上和下键 开机键 进入 fastbo…

Ubuntu 下 nginx-1.24.0 源码分析 - ngx_get_options函数

声明 就在 main函数所在的 nginx.c 中&#xff1a; static ngx_int_t ngx_get_options(int argc, char *const *argv); 实现 static ngx_int_t ngx_get_options(int argc, char *const *argv) {u_char *p;ngx_int_t i;for (i 1; i < argc; i) {p (u_char *) argv[i]…

Android的MQTT客户端实现

在 Android 平台上实现 MQTT 客户端的完整技术方案&#xff0c;涵盖基础实现、安全连接、性能优化和最佳实践&#xff1a; 一、技术选型与依赖配置 推荐库 Eclipse Paho Android Service&#xff08;官方维护&#xff0c;支持后台运行&#xff09; gradle 复制 // build.gradl…

MySQL性能优化MySQL索引失效的13种隐蔽场景排查及解决方法

在使用 MySQL 数据库时,索引是提高查询性能的重要手段。然而,如果索引使用不当,可能会导致索引失效,从而影响数据库的性能。本文将介绍 MySQL 索引失效场景,并通过实际案例进行详细分析,帮助你更好地理解和避免这些问题。 一、索引失效的13种隐蔽场景 1. 使用 OR 条件查…

Java Stream API:高效数据处理的利器引言

Java Stream API&#xff1a;高效数据处理的利器引言 在 Java 编程中&#xff0c;数据处理是一项极为常见且关键的任务。传统的 for 循环在处理数据集合时&#xff0c;往往会导致代码变得冗长、复杂&#xff0c;这不仅增加了代码的编写难度&#xff0c;还降低了代码的可读性和…

profinet转ModbusTCP网关,助机器人“掀起”工业智能的惊涛骇浪

在现代汽车制造过程中&#xff0c;生产设备的精确控制与实时监测是确保产品质量和生产效率的关键。某汽车制造厂在其生产线上应用了可编程逻辑控制器&#xff08;PLC&#xff09;和压力传感器&#xff0c;这两种设备分别使用稳联技术Profinet和ModbusTCP协议&#xff08; WL-A…

【DeepSeek论文翻译】DeepSeek-R1: 通过强化学习激励大型语言模型的推理能力

目录 摘要 1. 引言 2. 方法 2.1. 概述 2.2. DeepSeek-R1-Zero&#xff1a;在基础模型上进行强化学习 2.2.1. 强化学习算法 2.2.2. 奖励建模 2.2.3. 训练模板 2.2.4. DeepSeek-R1-Zero 的性能、自我进化过程和顿悟时刻 2.3. DeepSeek-R1&#xff1a;具有冷启动的强化学…

Ollama 本地部署 体验 deepseek

下载安装ollama,选择模型 进行部署 # 管理员命令行 执行 ollama run deepseek-r1:70b浏览器访问http://ip:11434/ 返回 Ollama is runninghttp://ip:11434/v1/models 返回当前部署的模型数据 下载安装CherryStudio&#xff0c;本地对话UI 客户端 在设置中 修改API地址&#x…