使用 PaddlePaddle 实现逻辑回归:从训练到模型保存与加载

devtools/2025/2/3 9:47:34/

在机器学习中,逻辑回归是一种经典的分类算法,广泛应用于二分类问题。今天,我们将通过一个简单的例子,使用 PaddlePaddle 框架实现逻辑回归模型,并展示如何保存和加载模型,以便进行后续的预测。

1. 简介

逻辑回归是一种线性分类模型,通过学习输入特征与输出标签之间的关系,实现对新数据的分类。PaddlePaddle 是一个开源的深度学习框架,提供了丰富的接口和工具,方便开发者快速实现和部署机器学习模型。

2. 数据准备

为了演示逻辑回归,我们生成了两组二维数据点,分别表示两个不同的类别。以下是数据的定义:

class1_points = np.array([[1.9, 1.2],[1.5, 2.1],[1.9, 0.5],[1.5, 0.9],[0.9, 1.2],[1.1, 1.7],[1.4, 1.1]])class2_points = np.array([[3.2, 3.2],[3.7, 2.9],[3.2, 2.6],[1.7, 3.3],[3.4, 2.6],[4.1, 2.3],[3.0, 2.9]])

我们将这两组数据合并,并为它们分配标签(0 和 1),表示不同的类别。然后,将数据转换为 Paddle 的 Tensor 格式,以便用于模型训练。

3. 模型定义

逻辑回归模型的核心是一个线性层,后面接一个 Sigmoid 激活函数。Sigmoid 函数将输出值映射到 (0, 1) 区间,表示属于某个类别的概率。以下是模型的定义:

class LogisticRegression(nn.Layer):def __init__(self):super(LogisticRegression, self).__init__()self.linear = nn.Linear(2, 1)  # 输入特征维度为2,输出为1def forward(self, x):return nn.functional.sigmoid(self.linear(x))

4. 训练模型

为了训练模型,我们需要定义优化器和损失函数。这里我们使用随机梯度下降(SGD)优化器和二元交叉熵损失函数(BCELoss)。训练过程如下:

optimizer = optim.SGD(parameters=model.parameters(), learning_rate=0.1)
loss_fn = nn.BCELoss()epochs = 1000
for epoch in range(epochs):y_pred = model(X)loss = loss_fn(y_pred, y)loss.backward()optimizer.step()optimizer.clear_grad()if epoch % 100 == 0:print(f"Epoch {epoch}, Loss: {loss.numpy()}")

在训练过程中,我们每 100 个 epoch 打印一次损失值,以便观察模型的收敛情况。

5. 保存模型

训练完成后,我们需要将模型的参数保存到文件中,以便后续加载和使用。PaddlePaddle 提供了 paddle.save 方法,可以方便地保存模型参数:

paddle.save(model.state_dict(), 'model2.pdparams')

6. 加载模型

在需要使用模型进行预测时,我们可以通过 paddle.load 方法加载保存的模型参数,并将其加载到模型中:

model_state_dict = paddle.load('model2.pdparams')
model.load_dict(model_state_dict)

7. 预测新数据

加载模型后,我们可以对新的数据点进行预测。以下是预测新数据的代码:

new_data = paddle.to_tensor([[2.0, 2.0], [3.5, 3.0]], dtype='float32')
predictions = model(new_data)
print("Predictions:", predictions.numpy())

预测结果是一个概率值,表示数据点属于类别 1 的概率。

8. 完整代码

以下是完整的代码实现:

import paddle
import paddle.nn as nn
import paddle.optimizer as optim
import numpy as np"""使用paddlepaddle框架实现逻辑回归并保存模型,然后保存模型后再加载模型进行预测"""
# 数据准备
class1_points = np.array([[1.9, 1.2],[1.5, 2.1],[1.9, 0.5],[1.5, 0.9],[0.9, 1.2],[1.1, 1.7],[1.4, 1.1]])class2_points = np.array([[3.2, 3.2],[3.7, 2.9],[3.2, 2.6],[1.7, 3.3],[3.4, 2.6],[4.1, 2.3],[3.0, 2.9]])# 将数据合并为一个数据集
X = np.vstack((class1_points, class2_points))
y = np.hstack((np.zeros(len(class1_points)), np.ones(len(class2_points))))# 转换为 Paddle 的 Tensor
X = paddle.to_tensor(X, dtype='float32')
y = paddle.to_tensor(y, dtype='float32').reshape([-1, 1])# 定义逻辑回归模型
class LogisticRegression(nn.Layer):def __init__(self):super(LogisticRegression, self).__init__()self.linear = nn.Linear(2, 1)  # 输入特征维度为2,输出为1def forward(self, x):return nn.functional.sigmoid(self.linear(x))# 实例化模型、优化器和损失函数
model = LogisticRegression()
optimizer = optim.SGD(parameters=model.parameters(), learning_rate=0.1)
loss_fn = nn.BCELoss()# 训练模型1000次,每100次展示一下loss
epochs = 1000
for epoch in range(epochs):y_pred = model(X)loss = loss_fn(y_pred, y)loss.backward()optimizer.step()optimizer.clear_grad()if epoch % 100 == 0:print(f"Epoch {epoch}, Loss: {loss.numpy()}")# 保存模型
paddle.save(model.state_dict(), 'model2.pdparams')# 加载模型
model_state_dict = paddle.load('model2.pdparams')
model.load_dict(model_state_dict)# 预测新数据
new_data = paddle.to_tensor([[2.0, 2.0], [3.5, 3.0]], dtype='float32')
predictions = model(new_data)
print("Predictions:", predictions.numpy())


http://www.ppmy.cn/devtools/155682.html

相关文章

排查定位jar包大文件

解压 JAR 包: mkdir jar_contents unzip your-jar-file.jar -d jar_contents统计各文件大小: du -ah jar_contents | sort -rh | head -n 20这会列出 JAR 包中最大的文件或目录,方便你定位大文件。 方法 2:使用 jar 工具查看文件…

使用eNSP配置GRE VPN实验

实验拓扑 实验需求 1.按照图示配置IP地址 2.在R1和R3上配置默认路由使公网区域互通 3.在R1和R3上配置GRE VPN,使两端私网能够互相访问,Tunne1口IP地址如图 4.在R1和R3上配置RIPv2来传递两端私网路由 实验步骤 GRE VPN配置方法: 发送端&#x…

【Rust】18.2. 可辩驳性:模式是否会无法匹配

喜欢的话别忘了点赞、收藏加关注哦(加关注即可阅读全文),对接下来的教程有兴趣的可以关注专栏。谢谢喵!(・ω・) 18.2.1. 模式的两种形式 模式有两种形式: 可辩驳的(可失败的&…

新一代搜索引擎,是 ES 的15倍?

Manticore Search介绍 Manticore Search 是一个使用 C 开发的高性能搜索引擎,创建于 2017 年,其前身是 Sphinx Search 。Manticore Search 充分利用了 Sphinx,显着改进了它的功能,修复了数百个错误,几乎完全重写了代码…

C#面向对象(封装)

1.什么是封装? C# 封装 封装 被定义为“把一个或多个项目封闭在一个物理的或者逻辑的包中”。 在面向对象程序设计方法论中,封装是为了防止对实现细节的访问。 抽象和封装是面向对象程序设计的相关特性。 抽象允许相关信息可视化,封装则使开发者实现所…

CVE-2023-38831 漏洞复现:win10 压缩包挂马攻击剖析

目录 前言 漏洞介绍 漏洞原理 产生条件 影响范围 防御措施 复现步骤 环境准备 具体操作 前言 在网络安全这片没有硝烟的战场上,新型漏洞如同隐匿的暗箭,时刻威胁着我们的数字生活。其中,CVE - 2023 - 38831 这个关联 Win10 压缩包挂…

Kafka常见问题之 java.io.IOException: Disk error when trying to write to log

文章目录 Kafka常见问题之 java.io.IOException: Disk error when trying to write to log1. 问题概述2. 问题排查方向(1)磁盘空间不足(2)磁盘 I/O 故障(3)Kafka 日志文件损坏(4)Kaf…

Elasticsearch的索引生命周期管理

目录 说明零、参考一、ILM的基本概念二、ILM的实践步骤Elasticsearch ILM策略中的“最小年龄”是如何计算的?如何监控和调整Elasticsearch ILM策略的性能? 1. **监控性能**使用/_cat/thread_pool API基本请求格式请求特定线程池的信息响应内容 2. **调整…