pytorch实现单层线性回归模型

server/2024/9/25 21:19:20/

文章目录

    • 简述
      • 代码重构要点
    • 数学模型、运行结果
    • 数据构建与分批
    • 模型封装
    • 运行测试

简述

python使用 数值微分法 求梯度,实现单层线性回归-CSDN博客
python使用 计算图(forward与backward) 求梯度,实现单层线性回归-CSDN博客
数值微分求梯度、计算图求梯度,实现单层线性回归 模型速度差异及损失率比对-CSDN博客

上述文章都是使用python来实现求梯度的,是为了学习原理,实际使用上,pytorch实现了自动求导,原理也是(基于计算图的)链式求导,本文还就 “单层线性回归” 问题用pytorch实现。

代码重构要点

1.nn.Moudle

torch.nn.Module的继承、nn.Sequentialnn.Linear
torch.nn — PyTorch 2.4 documentation

对于nn.Sequential的理解可以看python使用 计算图(forward与backward) 求梯度,实现单层线性回归-CSDN博客一文代码的模型初始化与计算部分,如图:

在这里插入图片描述

nn.Sequential可以说是把图中标注的代码封装起来了,并且可以放多层。

2.torch.optim优化器

本例中使用随机梯度下降torch.optim.SGD()
torch.optim — PyTorch 2.4 documentation
SGD — PyTorch 2.4 documentation

3.数据构建与数据加载

data.TensorDatasetdata.DataLoader,之前为了实现数据分批,手动实现了data_iter,现在可以直接调用pytorchdata.DataLoader

对于data.DataLoader的参数num_workers,默认值为0,即在主线程中处理,但设置其它值时存在反而速度变慢的情况,以后再讨论。

数学模型、运行结果

y = X W + b y = XW + b y=XW+b

y为标量,X列数为2. 损失函数使用均方误差。

运行结果:

在这里插入图片描述

在这里插入图片描述

数据构建与分批

def build_data(weights, bias, num_examples):  x = torch.randn(num_examples, len(weights))  y = x.matmul(weights) + bias  # 给y加个噪声  y += torch.randn(1)  return x, y  def load_array(data_arrays, batch_size, num_workers=0, is_train=True):  """构造一个PyTorch数据迭代器"""  dataset = data.TensorDataset(*data_arrays)  return data.DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=is_train)

模型封装

class TorchLinearNet(torch.nn.Module):  def __init__(self):  super(TorchLinearNet, self).__init__()  model = nn.Sequential(Linear(in_features=2, out_features=1))  self.model = model  self.criterion = nn.MSELoss()  def predict(self, x):  return self.model(x)  def loss(self, y_predict, y):  return self.criterion(y_predict, y)

运行测试

if __name__ == '__main__':  start = time.perf_counter()  true_w1 = torch.rand(2, 1)  true_b1 = torch.rand(1)  x_train, y_train = build_data(true_w1, true_b1, 5000)  net = TorchLinearNet()  print(net)  init_loss = net.loss(net.predict(x_train), y_train)  loss_history = list()  loss_history.append(init_loss.item())  num_epochs = 3  batch_size = 50  learning_rate = 0.01  dataloader_workers = 6  data_loader = load_array((x_train, y_train), batch_size=batch_size, is_train=True)  optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate)  for epoch in range(num_epochs):  # running_loss = 0.0  for x, y in data_loader:  y_pred = net.predict(x)  loss = net.loss(y_pred, y)  optimizer.zero_grad()  loss.backward()  optimizer.step()  # running_loss = running_loss + loss.item()  loss_history.append(loss.item())  end = time.perf_counter()  print(f"运行时间(不含绘图时间):{(end - start) * 1000}毫秒\n")  plt.title("pytorch实现单层线性回归模型", fontproperties="STSong")  plt.xlabel("epoch")  plt.ylabel("loss")  plt.plot(loss_history, linestyle='dotted')  plt.show()  print(f'初始损失值:{init_loss}')  print(f'最后一次损失值:{loss_history[-1]}\n')  print(f'正确参数: true_w1={true_w1}, true_b1={true_b1}')  print(f'预测参数:{net.model.state_dict()}')

http://www.ppmy.cn/server/101576.html

相关文章

Ubuntu安装cuda

本文详细介绍了在 Ubuntu 系统上安装 CUDA 的全过程。从安装前的系统要求和准备工作,到具体的安装步骤,包括下载 CUDA 安装文件、处理依赖关系、执行安装命令以及配置环境变量等。旨在为需要在 Ubuntu 中安装 CUDA 以进行深度学习、图形计算等工作的用户…

代码随想录算法训练营Day39 | 322. 零钱兑换 | 279.完全平方数 | 139.单词拆分

今日任务 322. 零钱兑换 题目链接&#xff1a; https://leetcode.cn/problems/coin-change/description/题目描述&#xff1a; Code class Solution { public:int coinChange(vector<int>& coins, int amount) {int n coins.size();// vector<vector<int…

c# 什么是扩展方法

官方解释 扩展方法使你能够向现有类型“添加”方法&#xff0c;而无需创建新的派生类型、重新编译或以其他方式修改原始类型。 扩展方法是一种静态方法&#xff0c;但可以像扩展类型上的实例方法一样进行调用。 对于用 C#、F# 和 Visual Basic 编写的客户端代码&#x…

抽象代数精解【9】

文章目录 流密码密码体制概述唯吉尼亚密码一、历史与背景二、加密算法三、特点与应用四、破译方法五、原理概述加密过程解密过程注意事项 流密码理论解释一、定义与原理二、特点与优势三、工作原理四、应用实例五、安全性与限制 RC4算法一、算法概述二、算法原理三、算法特点四…

只有IP如何实现https访问

IP也是访问网站的一种方式&#xff0c;现在有很多网站并未绑定域名&#xff0c;而是通过IP直接访问的。 但是域名访问网站的方式会更多一些&#xff0c;主要还是因为域名相较于IP数字要更加好记&#xff0c;所以域名绑定网站的情况会更多。 随着现在网络安全意识的逐渐提升&a…

ios app包应用签名证书指纹SHA256值

获取应用签名证书的指纹&#xff0c;首先要获取给app签名的证书&#xff0c;然后从证书里面获取SHA256签名&#xff0c;具体步骤如下 1 获取iOS app签名证书指纹SHA256值2 导出p12文件3 获取证书指纹SHA256值4 完成 操作步骤及代码 步骤1&#xff1a;首先&#xff0c;你需要…

python爬虫滑块验证及各种加密函数(基于ddddocr进行的一层封装)

git链接: https://github.com/JOUUUSKA/spider_toolsbox 这里写目录标题 一.识别验证码1、识别英文&#xff0b;数字验证码2、识别滑块验证码3、识别点选验证码 二、下载系列1、下载视频2、下载图片3、下载文本 三、常用加密类型1、AES系列2、DES系列3、RSA系列4、SHA系列5、B…

选对BI解决方案,数据才能驱动成功?奥威BI数据可视化方案深度解析

选对BI解决方案&#xff0c;数据才能驱动成功&#xff1f;奥威BI数据可视化方案深度解析 在当今这个数据爆炸的时代&#xff0c;企业面临着前所未有的机遇与挑战。如何有效利用数据&#xff0c;将其转化为推动业务增长和决策优化的关键力量&#xff0c;成为了每个企业都必须面…