PyTorch基本使用-线性回归案例

server/2024/12/15 20:36:13/

文章目录

    • 1. 训练模型步骤
    • 2. 训练模型API
    • 3. 训练模型

学习目标:掌握PyTorch构建线性回归模型相关API

1. 训练模型步骤

我们使用 PyTorch 的各个组件来构建线性回归的实现。在pytorch中进行模型构建的整个流程一般分为四个步骤:

  1. 准备训练数据集
  2. 构建要使用的模型
  3. 设置损失函数及优化器
  4. 训练模型
    在这里插入图片描述

2. 训练模型API

  • 使用 PyTorch 的 nn.MSELoss()代替自定义的平方损失函数
  • 使用 PyTorch 的 data.DataLoader代替自定义的数据加载器
  • 使用 PyTorch 的 optim.SGD代替自定义的优化器
  • 使用 PyTorch 的 nn.Linear代替自定义的假设函数

3. 训练模型

具体代码:

# 导入相关模块
import torch
from sympy import false
from torch.utils.data import  TensorDataset # 构造数据集对象
from torch.utils.data import DataLoader # 数据加载器
from  torch import  nn # nn模块中有平方损失函数和假设函数
from torch import optim # optim 模块中有优化器函数
from sklearn.datasets import make_regression # 创建线性回归模型数据集
import  matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号# 数据集构建
def create_dataset():x, y, coef = make_regression(n_samples=100,n_features=1,noise=10,coef=True,bias=1.5,random_state=0)x = torch.tensor(x)y = torch.tensor(y)return x, y, coefdef train():# 1. 构造数据集x, y, coef = create_dataset()# 构造数据集对象dataset = TensorDataset(x, y)# 构造数据加载器# dataset=:数据集对象# batch_size=:批量训练样本数据# shuffle=:样本数据是否进行乱序dataloader = DataLoader(dataset=dataset, batch_size=16, shuffle=True)# 2. 构造模型# in_features=:输入张量的大小size# out_features=:输出张量的大小sizemodel = nn.Linear(in_features=1, out_features=1)# 3.设置损失函数和优化器# 构造平方损失函数criterion = nn.MSELoss()# 构造优化函数optimizer = optim.SGD(model.parameters(), lr=1e-2)# 4.训练模型epochs = 100# 损失变化loss_epochs = []total_loss=0.0train_sample=0.0for _ in range(epochs):for train_x,train_y in dataloader:# 将一个batch的训练数据送入模型y_pred = model(train_x.type(torch.float32))# 计算损失函数值loss = criterion(y_pred, train_y.reshape(-1, 1).type(torch.float32))total_loss += loss.item()train_sample += len(train_x)# 梯度清零optimizer.zero_grad()# 自动微分(反向传播)loss.backward()# 更新参数optimizer.step()# 获取每次batch的损失loss_epochs.append(total_loss/train_sample)# 绘制损失变化曲线plt.plot(range(epochs), loss_epochs)plt.title('损失变化曲线')plt.grid()plt.show()# 绘制拟合直线plt.scatter(x,y)x = torch.linspace(x.min(), x.max(), 1000)y1 = torch.tensor([v * model.weight + model.bias for v in x])y2 = torch.tensor([v * coef +1.5 for v in x])plt.plot(x, y1,label='训练')plt.plot(x, y2,label='真实')plt.grid()plt.legend()plt.show()if __name__ == '__main__':# 生成数据# x, y, coef = create_dataset()# # 绘制数据的真实线性回归结果# plt.scatter(x,y)# x = torch.linspace(x.min(), x.max(),1000)# y1 = torch.tensor([v * coef + 1.5 for v in x])# plt.plot(x,y1,label='real')# plt.grid()# plt.legend()# plt.show()train()

输出结果:
在这里插入图片描述

在这里插入图片描述


http://www.ppmy.cn/server/150432.html

相关文章

《智能体开发实战(高阶)》四、系统化的日志周报智能体开发计划

智能体扩展与完善规划 为了将前几个章节的智能体逐步扩展为支持整个公司团队使用的高效工具,以下是分阶段的完善与扩写规划。每个阶段旨在提升功能覆盖范围、处理能力和用户体验,并为企业提供实际价值。 阶段一:基础功能完善 目标:巩固现有功能,提升健壮性和适用性。 支…

android 底层硬件通知webview 技术—未来之窗行业应用跨平台架构

String 未来之窗反向js2 "javascript:" "东方仙盟技术" "(\"nfc_reader\"," 未来之窗NFC ")"; cwpd_Web.evaluateJavascript(未来之窗反向js2, new ValueCallback<String>() { …

「Mac玩转仓颉内测版50」小学奥数篇13 - 动态规划入门

本篇将通过 Python 和 Cangjie 双语介绍动态规划的基本概念&#xff0c;并解决一个经典问题&#xff1a;斐波那契数列。学生将学习如何使用动态规划优化递归计算&#xff0c;并掌握编程中的重要算法思想。 关键词 小学奥数Python Cangjie动态规划斐波那契数列 一、题目描述 …

docker-4.迁移存储目录

docker pull 拉取镜像时候磁盘空间满,迁移/var/lib/docker目录 目录 1. 清理Docker占用的磁盘空间2.迁移 /var/lib/docker 目录3.开机自动挂载文件/etc/fstab4.docker国内镜像源1. 清理Docker占用的磁盘空间 清理空间: Docker System命令, 在《谁用光了磁盘?Docker System…

路由介绍.

RIB和FIB Routing Information Base&#xff08;RIB&#xff09;&#xff0c;即路由信息库&#xff0c;是存储在路由器或联网计算机中的一个电子表格或类数据库&#xff0c;它保存着指向特定网络地址的路径信息&#xff0c;包括路径的路由度量值。RIB的主要目标是实现路由协议…

docker容器内部启动jupyter notebook但是宿主机无法访问的解决方法

目录 1.问题2.解决方法 1.问题 在docker容器内启动了jupyter notebook&#xff0c;在宿主机内用如下的url无法访问 http://localhost:8888 http://127.0.0.1:8888 启动方法&#xff1a; jupyter notebook 2.解决方法 启动方法加上选项[ --ip‘*’]或者[–ip‘0.0.0.0’] 即启…

Rust 编程语言介绍

一、基本介绍 Rust 是一种系统编程语言&#xff0c;由 Mozilla 研究院开发。它的设计目标是在保证高性能的同时&#xff0c;提供内存安全和线程安全。相比C和C语言具有下面几个特点&#xff1a; 内存安全&#xff1a;在传统的编程语言如 C 和 C 中&#xff0c;手动管理内存可…

clipboard----封装复制组件

Clipboard.js 是一个轻量级的 JavaScript 库&#xff0c;旨在帮助开发者轻松地实现将文本复制到剪贴板的功能。它不依赖 Flash 或其他外部库&#xff0c;并且提供了一种简单的方式来响应用户的复制行为。Clipboard.js 支持绑定到任何元素&#xff08;如按钮、图片等&#xff09…