自定义数据集 使用paddlepaddle框架实现逻辑回归

devtools/2025/2/5 13:14:36/

导入必要的库

import numpy as np
import paddle
import paddle.nn as nn

数据准备:

seed=1
paddle.seed(seed)# 1.散点输入 定义输入数据
data = [[-0.5, 7.7], [1.8, 98.5], [0.9, 57.8], [0.4, 39.2], [-1.4, -15.7], [-1.4, -37.3], [-1.8, -49.1], [1.5, 75.6], [0.4, 34.0], [0.8, 62.3]]
#转化为数组
data=np.array(data)
# 提取x 和y
x_data=data[:,0]
y_data=data[:,1]
#转成张量 转成paddlepaddle张量
x_train=paddle.to_tensor(x_data,dtype=paddle.float32)
y_train=paddle.to_tensor(y_data,dtype=paddle.float32)

定义模型:

class LinearModel(nn.Layer):def __init__(self):super(LinearModel,self).__init__()self.linear=nn.Linear(1,1)def forward(self,x):x=self.linear(x)return x
#定义模型的对象
model=LinearModel()

损失函数和优化器:

#3.1损失函数
criterion=paddle.nn.MSELoss()
#3.2 优化器
optimizer=paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters())

模型训练和保存:

epochs=500
final_checkpoint={}
for epoch in range(1,epochs+1):#前向传播#unsqueeze()扩展一维y_prd=model(x_train.unsqueeze(1))loss=criterion(y_prd.squeeze(1),y_train)#清除之前计算的梯度optimizer.clear_grad()#自动计算梯度loss.backward()#更新参数optimizer.step()# 5.显示频率的设置if epoch % 10==0 or epoch==1:#可以使用float(loss)或者 loss.numpy()会报警告print(f"epoch:{epoch},loss:{float(loss)}")#添加检查点程序if epoch==epochs:#把迭代次数写入final_checkpoint['epoch']=epoch#把训练损失写入final_checkpoint['loss']=loss#基础API模型的保存
paddle.save(model.state_dict(),'./基础API/model.pdparams')
#保存检查点checkpoint信息 是序列化的文件
paddle.save(final_checkpoint, "./基础API/final_checkpoint.pkl")

模型加载及预测:

#基础API模型的加载
model_state_dict=paddle.load('./基础API/model.pdparams')
# optimizer_state_dict=paddle.load('./基础API/optimizer.pdopt')
final_checkpoint_state_dict=paddle.load('./基础API/final_checkpoint.pkl')
print(final_checkpoint_state_dict)#模型和参数联系起来
model.set_state_dict(model_state_dict)#训练 评估 和推理
# 模型验证模式
model.eval()
#使用TensorDateset 和DateLoader封装
dataloader_test=DataLoader(TensorDataset([paddle.to_tensor([1.5],dtype=paddle.float32)]),batch_size=1)#迭代
for x_test in dataloader_test:predict=model(x_test[0])print(predict)

结果展示:


http://www.ppmy.cn/devtools/156279.html

相关文章

防火墙安全策略实验

一、实验拓扑图及实验要求 实验要求: 1、VLAN2属于办公区;VLAN 3属于生产区。 2、办公区PC在工作日时间(周一至周五,早8到晚6)可以正常访问0A Server,其他时间不允许。 3、办公区可以在任意时刻访问web Server 4、生产…

嵌入式知识点总结 操作系统 专题提升(四)-上下文

针对于嵌入式软件杂乱的知识点总结起来,提供给读者学习复习对下述内容的强化。 目录 1.上下文有哪些?怎么理解? 2.为什么会有上下文这种概念? 3.什么情况下进行用户态到内核态的切换? 4.中断上下文代码中有哪些注意事项? 5.请问线程需要保存哪些…

JavaScript面向对象编程:Prototype与Class的对比详解

JavaScript面向对象编程:Prototype与Class的对比详解 JavaScript面向对象编程:Prototype与Class的对比详解引言什么是JavaScript的面向对象编程?什么是Prototype?Prototype的定义Prototype的工作原理示例代码优点缺点 什么是JavaS…

UE5 蓝图学习计划 - Day 11:材质与特效

在游戏开发中,材质(Material)与特效(VFX) 是提升视觉体验的关键元素。Unreal Engine 5 提供了强大的 材质系统 和 粒子系统(Niagara),让开发者可以通过蓝图控制 动态材质、光效变化、…

Linux stat 命令使用详解

简介 stat 命令打印文件和文件系统的详细信息。该工具提供有关所有者是谁、修改日期、访问权限、大小、类型等信息。 该实用程序对于故障排除、在更改文件之前获取有关文件的信息以及例行文件和系统管理任务至关重要。 基本语法 stat [arguments] [filename]常用选项 -L, -…

99.23 金融难点通俗解释:小卖部经营比喻PPI(生产者物价指数)vsCPI(消费者物价指数)

目录 0. 承前1. 简述:价格指数对比2. 比喻:两大指数对比2.1 简单对比2.2 生动比喻 3. 实际应用3.1 价格传导现象 4. 总结5. 有趣的对比6. 数据获取实现代码7. 数据可视化实现代码 0. 承前 本文主旨: 本文使用小卖部比喻PPI和CPI,…

Redis基础(二)——通用命令与五大基本数据类型

目录 一、Redis数据结构基本介绍 二、Redis通用命令 1.查看通用命令 2.KEYS:查看符合模板的所有key 3.DEL:删除指定的Key 4.lEXISTS:判断key是否存在 5.lEXPIRE:给一个key设置有效期,有效期到期时该key会被自…

redis实现延迟任务

定时任务:有固定周期,有明确的触发时间 延迟任务:没有固定的开始时间,由一个事件触发,在这个事件触发之后的一段时间内触发另一个事件,任务可以立即执行,也可以延迟执行。 场景1:订…