自定义数据集 使用paddlepaddle框架实现逻辑回归

news/2025/2/5 20:54:27/

导入必要的库

import numpy as np
import paddle
import paddle.nn as nn

数据准备:

seed=1
paddle.seed(seed)# 1.散点输入 定义输入数据
data = [[-0.5, 7.7], [1.8, 98.5], [0.9, 57.8], [0.4, 39.2], [-1.4, -15.7], [-1.4, -37.3], [-1.8, -49.1], [1.5, 75.6], [0.4, 34.0], [0.8, 62.3]]
#转化为数组
data=np.array(data)
# 提取x 和y
x_data=data[:,0]
y_data=data[:,1]
#转成张量 转成paddlepaddle张量
x_train=paddle.to_tensor(x_data,dtype=paddle.float32)
y_train=paddle.to_tensor(y_data,dtype=paddle.float32)

定义模型:

class LinearModel(nn.Layer):def __init__(self):super(LinearModel,self).__init__()self.linear=nn.Linear(1,1)def forward(self,x):x=self.linear(x)return x
#定义模型的对象
model=LinearModel()

损失函数和优化器:

#3.1损失函数
criterion=paddle.nn.MSELoss()
#3.2 优化器
optimizer=paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters())

模型训练和保存:

epochs=500
final_checkpoint={}
for epoch in range(1,epochs+1):#前向传播#unsqueeze()扩展一维y_prd=model(x_train.unsqueeze(1))loss=criterion(y_prd.squeeze(1),y_train)#清除之前计算的梯度optimizer.clear_grad()#自动计算梯度loss.backward()#更新参数optimizer.step()# 5.显示频率的设置if epoch % 10==0 or epoch==1:#可以使用float(loss)或者 loss.numpy()会报警告print(f"epoch:{epoch},loss:{float(loss)}")#添加检查点程序if epoch==epochs:#把迭代次数写入final_checkpoint['epoch']=epoch#把训练损失写入final_checkpoint['loss']=loss#基础API模型的保存
paddle.save(model.state_dict(),'./基础API/model.pdparams')
#保存检查点checkpoint信息 是序列化的文件
paddle.save(final_checkpoint, "./基础API/final_checkpoint.pkl")

模型加载及预测:

#基础API模型的加载
model_state_dict=paddle.load('./基础API/model.pdparams')
# optimizer_state_dict=paddle.load('./基础API/optimizer.pdopt')
final_checkpoint_state_dict=paddle.load('./基础API/final_checkpoint.pkl')
print(final_checkpoint_state_dict)#模型和参数联系起来
model.set_state_dict(model_state_dict)#训练 评估 和推理
# 模型验证模式
model.eval()
#使用TensorDateset 和DateLoader封装
dataloader_test=DataLoader(TensorDataset([paddle.to_tensor([1.5],dtype=paddle.float32)]),batch_size=1)#迭代
for x_test in dataloader_test:predict=model(x_test[0])print(predict)

结果展示:


http://www.ppmy.cn/news/1569601.html

相关文章

DeepSeek蒸馏模型:轻量化AI的演进与突破

目录 引言 一、知识蒸馏的技术逻辑与DeepSeek的实践 1.1 知识蒸馏的核心思想 1.2 DeepSeek的蒸馏架构设计 二、DeepSeek蒸馏模型的性能优势 2.1 效率与成本的革命性提升 2.2 性能保留的突破 2.3 场景适应性的扩展 三、应用场景与落地实践 3.1 智能客服系统的升级 3.2…

【Qt】06-对话框

对话框 前言一、模态和非模态对话框1.1 概念1.2 模态对话框1.2.1 代码QAction类 1.2.2 模态对话框运行分析 1.3 非模态对话框1.3.1 代码局部变量和成员变量setAttribute 类 1.3.2 现象解释 二、标准对话框2.1 提示对话框 QMessageBox2.1.1 现象及解释 2.2 问题对话框2.2.1 现象…

TCP UDP Service Model

主机A的TCP层可以通过发送FIN消息来关闭链接,主机B确认A不再有数据发送,并停止从A接收新数据。 B完成向A发送数据,并发送自己的FIN消息,告知A它们可以关闭链接。 主机A通过发送ACK作为回应,确认链接现已关闭。 &…

MacBook Pro(M1芯片)Qt环境配置

MacBook Pro(M1芯片)Qt环境配置 1、准备 试图写一个跨平台的桌面应用,此时想到了使用Qt,于是开始了搭建开发环境~ 在M1芯片的电脑上安装,使用brew工具比较方便 Apple Silicon(ARM/M1&#xf…

二叉树——429,515,116

今天继续做关于二叉树层序遍历的相关题目,一共有三道题,思路都借鉴于最基础的二叉树的层序遍历。 LeetCode429.N叉树的层序遍历 这道题不再是二叉树了,变成了N叉树,也就是该树每一个节点的子节点数量不确定,可能为2&a…

deepseek-r1(Mac版 安装教程)

文章目录 deepseek-r1安装教程(Mac)1. 安装ollama2. 本地下载对应的模型3. 使用3.1 终端直接使用3.2 网页使用 deepseek-r1安装教程(Mac) 1. 安装ollama 如果之前没有安装过ollama的,需要在ollama官网下载对应系统的o…

chrome浏览器chromedriver下载

chromedriver 下载地址 https://googlechromelabs.github.io/chrome-for-testing/ 上面的链接有和当前发布的chrome浏览器版本相近的chromedriver 实际使用感受 chrome浏览器会自动更新,可以去下载最新的chromedriver使用,自动化中使用新的chromedr…

【论文复现】粘菌算法在最优经济排放调度中的发展与应用

目录 1.摘要2.黏菌算法SMA原理3.改进策略4.结果展示5.参考文献6.代码获取 1.摘要 本文提出了一种改进粘菌算法(ISMA),并将其应用于考虑阀点效应的单目标和双目标经济与排放调度(EED)问题。为提升传统粘菌算法&#xf…