使用 PyTorch 实现线性回归:从零开始的完整指南

server/2025/2/1 21:34:49/

在机器学习中,线性回归是最基础且广泛使用的算法之一。它通过拟合数据点之间的线性关系,帮助我们理解和预测变量之间的关系。本文将通过一个简单的例子,展示如何使用 PyTorch 框架实现线性回归,并对自定义数据集进行拟合。

1. 线性回归简介

线性回归的目标是找到一个线性方程 y=wx+b,其中 w 是斜率,b 是截距,使得该方程能够尽可能地拟合给定的数据点。在实际应用中,我们通常使用最小二乘法来最小化预测值与真实值之间的误差。

2. 准备数据

首先,我们需要准备一个简单的数据集。在这个例子中,我们将使用一个包含 10 个数据点的自定义数据集:

data = [[-0.5, 7.7],[1.8, 98.5],[0.9, 57.8],[0.4, 39.2],[-1.4, -15.7],[-1.4, -37.3],[-1.8, -49.1],[1.5, 75.6],[0.4, 34.0],[0.8, 62.3]
]

这些数据点表示输入特征 x 和目标变量 y 之间的关系。我们将使用 PyTorch 的张量(Tensor)来存储和处理这些数据。

3. 构建线性回归模型

接下来,我们需要定义一个线性回归模型。在 PyTorch 中,可以通过继承 nn.Module 来定义一个自定义模型。我们将使用一个简单的线性层来实现这个模型:

class LinearModel(nn.Module):def __init__(self):super(LinearModel, self).__init__()self.layers = nn.ModuleList([nn.Linear(1, 1)])def forward(self, x):for layer in self.layers:x = layer(x)return x

这个模型包含一个线性层,其输入维度为 1,输出维度也为 1,正好符合我们的问题需求。

4. 定义损失函数和优化器

为了训练模型,我们需要定义一个损失函数和一个优化器。在这里,我们使用均方误差(MSE)作为损失函数,使用随机梯度下降(SGD)作为优化器:

criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

5. 训练模型

现在,我们可以开始训练模型了。我们将数据集输入模型,计算损失,并通过反向传播更新模型参数。以下是完整的训练代码:

epochs = 500
for n in range(1, epochs + 1):y_pred = model(x_train.unsqueeze(1))loss = criterion(y_pred.squeeze(1), y_train)optimizer.zero_grad()loss.backward()optimizer.step()if n % 10 == 0 or n == 1:print(f"Epoch: {n}, Loss: {loss.item():.4f}")

在每个 epoch 中,我们计算模型的预测值,计算损失,并通过 loss.backward() 计算梯度,最后通过 optimizer.step() 更新模型参数。

6. 可视化结果

训练完成后,我们可以通过绘制原始数据点和拟合的直线来直观地展示模型的效果。以下是完整的可视化代码:

plt.rcParams['font.sans-serif'] = ['SimHei']  # 指定中文字体为黑体
plt.rcParams['axes.unicode_minus'] = False  # 正确显示负号# 绘制原始数据点
plt.scatter(x_data, y_data, color='blue', label='原始数据')# 绘制拟合的直线
slope = model.layers[0].weight.item()
intercept = model.layers[0].bias.item()
x_fit = np.linspace(x_data.min(), x_data.max(), 100)
y_fit = slope * x_fit + intercept
plt.plot(x_fit, y_fit, color='red', label='拟合直线')# 添加图例和标签
plt.xlabel('X')
plt.ylabel('Y')
plt.legend()
plt.title('线性回归拟合结果')
plt.show()

运行上述代码后,你将看到如下图像:

从图中可以看出,拟合的直线能够较好地反映数据点之间的线性关系。

7. 总结

通过本文的介绍,你已经学会了如何使用 PyTorch 实现线性回归,并对自定义数据集进行拟合。线性回归虽然简单,但在许多实际问题中都非常有效。希望这篇文章能够帮助你更好地理解和应用线性回归模型。


代码完整版

以下是完整的代码,供你参考和使用:

import torch
import torch.nn as nn
import numpy as np
from matplotlib import pyplot as plt# 设置 matplotlib 支持中文显示
plt.rcParams['font.sans-serif'] = ['SimHei']  # 指定中文字体为黑体
plt.rcParams['axes.unicode_minus'] = False  # 正确显示负号# 定义输入数据
data = [[-0.5, 7.7],[1.8, 98.5],[0.9, 57.8],[0.4, 39.2],[-1.4, -15.7],[-1.4, -37.3],[-1.8, -49.1],[1.5, 75.6],[0.4, 34.0],[0.8, 62.3]
]# 转换为 NumPy 数组
data = np.array(data)
# 提取 x_data 和 y_data
x_data = data[:, 0]
y_data = data[:, 1]# 将 x_data 和 y_data 转化成 tensor
x_train = torch.tensor(x_data, dtype=torch.float32)
y_train = torch.tensor(y_data, dtype=torch.float32)# 定义损失函数
criterion = nn.MSELoss()# 定义线性回归模型
class LinearModel(nn.Module):def __init__(self):super(LinearModel, self).__init__()self.layers = nn.ModuleList([nn.Linear(1, 1)])def forward(self, x):for layer in self.layers:x = layer(x)return xmodel = LinearModel()# 定义优化器
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)# 训练模型
epochs = 500
for n in range(1, epochs + 1):y_pred = model(x_train.unsqueeze(1))loss = criterion(y_pred.squeeze(1), y_train)optimizer.zero_grad()loss.backward()optimizer.step()if n % 10 == 0 or n == 1:print(f"Epoch: {n}, Loss: {loss.item():.4f}")# 绘制图像
# 绘制原始数据点
plt.scatter(x_data, y_data, color='blue', label='原始数据')# 绘制拟合的直线
slope = model.layers[0].weight.item()
intercept = model.layers[0].bias.item()
x_fit = np.linspace(x_data.min(), x_data.max(), 100)
y_fit = slope * x_fit + intercept
plt.plot(x_fit, y_fit, color='red', label='拟合直线')# 添加图例和标签
plt.xlabel('X')
plt.ylabel('Y')
plt.legend()
plt.title('线性回归拟合结果')
plt.show()


http://www.ppmy.cn/server/164171.html

相关文章

【漫话机器学习系列】066.贪心算法(Greedy Algorithms)

贪心算法(Greedy Algorithms) 贪心算法是一种逐步构建解决方案的算法,每一步都选择当前状态下最优的局部选项(即“贪心选择”),以期望最终获得全局最优解。贪心算法常用于解决最优化问题。 核心思想 贪心选…

Unity游戏(Assault空对地打击)开发(2) 基础场景布置

目录 导入插件 文件夹整理 场景布置 山地场景 导入插件 打开【My Assets】(如果你刚进行上篇的操作,该窗口默认已经打开了)。 找到添加的几个插件,点击Download并Import x.x to...。 文件夹整理 我们的目录下多了两个文件夹&a…

Kotlin 2.1.0 入门教程(九)

类型检查和转换 在 Kotlin 中,可以执行类型检查以在运行时检查对象的类型。类型转换能够将对象转换为不同的类型。 is 和 !is 操作符 要执行运行时检查以确定对象是否符合给定类型,请使用 is 操作符或其否定形式 !is。 if (obj is String) {print(ob…

MySQL查询优化(三):深度解读 MySQL客户端和服务端协议

如果需要从 MySQL 服务端获得很高的性能,最佳的方式就是花时间研究 MySQL 优化和执行查询的机制。一旦理解了这些,大部分的查询优化是有据可循的,从而使得整个查询优化的过程更有逻辑性。下图展示了 MySQL 执行查询的过程: 客户端…

提示词工程

1、什么构成了一个好的提示 提示:输入给AI的问题或指令 好的提示能极大地提高AI的理解和执行的效率,让AI提供更准确和有用的回答。 提示工程(Prompt Engineering):研究如何写出好的提示 提示工程原则: …

阿里云域名备案

一、下载阿里云App 手机应用商店搜索"阿里云",点击安装。 二、登录阿里云账号 三、打开"ICP备案" 点击"运维"页面的"ICP备案"。 四、点击"新增网站/App" 若无备案信息,则先新增备案信息。 五、开始备案

OPENPPP2 —— VMUX_NET 多路复用原理剖析

在阅读本文之前,必先了解以下几个概念: 1、MUX(Multiplexer):合并多个信号到单一通道。 2、DEMUX(Demultiplexer):从单一通道分离出多个信号。 3、单一通道,可汇聚多个…

手撕Diffusion系列 - 第九期 - 改进为Stable Diffusion(原理介绍)

手撕Diffusion系列 - 第九期 - 改进为Stable Diffusion(原理介绍) 目录 手撕Diffusion系列 - 第九期 - 改进为Stable Diffusion(原理介绍)DDPM 原理图Stable Diffusion 原理Stable Diffusion的原理解释Stable Diffusion 和 Diffus…