使用PaddlePaddle实现逻辑回归:从训练到模型保存与加载

embedded/2025/2/4 9:23:15/

1. 引入必要的库

首先,需要引入必要的库。PaddlePaddle用于构建和训练模型,pandas和numpy用于数据处理,matplotlib用于结果的可视化。

import paddle
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

2. 加载自定义数据集

假设有一个CSV文件custom_dataset.csv,其中包含特征(自变量)和标签(因变量)。使用pandas来加载数据,并进行预处理。

# 加载自定义数据集
data = pd.read_csv('custom_dataset.csv')# 假设数据集中有多列特征和一个二分类标签
X = data.iloc[:, :-1].values.astype(np.float32)  # 特征
y = data.iloc[:, -1].values.astype(np.float32)   # 标签# 将标签转换为0和1
y = np.where(y == 'positive', 1, 0)

3. 构建逻辑回归模型

使用PaddlePaddle来构建逻辑回归模型。

# 构建逻辑回归模型
class LogisticRegression(paddle.nn.Layer):def __init__(self, num_features):super(LogisticRegression, self).__init__()self.linear = paddle.nn.Linear(num_features, 1)def forward(self, x):return paddle.sigmoid(self.linear(x))# 初始化模型
num_features = X.shape[1]
model = LogisticRegression(num_features)

4. 定义损失函数和优化器

使用二元交叉熵损失函数和随机梯度下降(SGD)优化器。

# 定义损失函数和优化器
criterion = paddle.nn.BCELoss()
optimizer = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters())

5. 训练模型

使用自定义数据集训练模型。

# 将数据转换为PaddlePaddle的张量
X_tensor = paddle.to_tensor(X)
y_tensor = paddle.to_tensor(y.reshape(-1, 1))# 训练模型
num_epochs = 100
batch_size = 32
for epoch in range(num_epochs):for i in range(0, len(X), batch_size):X_batch = X_tensor[i:i+batch_size]y_batch = y_tensor[i:i+batch_size]# 前向传播outputs = model(X_batch)loss = criterion(outputs, y_batch)# 反向传播和优化loss.backward()optimizer.step()optimizer.clear_grad()if (epoch+1) % 10 == 0:print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.numpy()}')

6. 保存模型

训练完成后,可以使用PaddlePaddle的save方法保存模型。

# 保存模型
paddle.save(model.state_dict(), 'logistic_regression_model.pdparams')

7. 加载模型并进行预测

在需要时,可以使用PaddlePaddle的load方法加载模型,并进行预测。

# 加载模型
model = LogisticRegression(num_features)
model.set_state_dict(paddle.load('logistic_regression_model.pdparams'))
model.eval()# 进行预测
X_test = paddle.to_tensor(X[:5])
predictions = model(X_test)
predicted_labels = (predictions > 0.5).astype(int)print("Predicted Labels:", predicted_labels.numpy().flatten())

8. 结果可视化

如果需要,可以绘制训练过程中的损失变化曲线,以帮助理解模型的性能。

# 这里假设我们在训练过程中记录了损失值
# plt.plot(loss_values, label='Loss')
# plt.title('Model Loss')
# plt.xlabel('Epochs')
# plt.ylabel('Loss')
# plt.legend()
# plt.show()


http://www.ppmy.cn/embedded/159425.html

相关文章

61.异步编程1 C#例子 WPF例子

和普通的任务绑定不太相同的部分如下: public MainWindowViewModel(){FetchUserInfoCommand new RelayCommand(async (param) > await FetchUserInfoAsync());}private async Task FetchUserInfoAsync(){// 模拟异步操作,比如网络请求await Task.Del…

【14】WLC3504 HA配置实例

1.概述 本文档使用 Cisco WLC 3504 实现无线控制器的高可用性。这里所指的HA是指WLC设备box-to-box的冗余。换句话说,即1:1的设备冗余,其中一个 WLC 将处于Active活动状态,而第二个 WLC 将处于Standby-hot热待机状态,通过RP冗余端口持续监控活动 WLC 的运行状况。两个 WLC…

【硬件测试】基于FPGA的QPSK+帧同步系统开发与硬件片内测试,包含高斯信道,误码统计,可设置SNR

目录 1.算法仿真效果 2.算法涉及理论知识概要 2.1QPSK 2.2 帧同步 3.Verilog核心程序 4.开发板使用说明和如何移植不同的开发板 5.完整算法代码文件获得 1.算法仿真效果 本文是之前写的文章 《基于FPGA的QPSK帧同步系统verilog开发,包含testbench,高斯信道,误码统计,可…

从理论到实践:Linux 进程替换与 exec 系列函数

个人主页:chian-ocean 文章专栏-Linux 前言: 在Linux中,进程替换(Process Substitution)是一个非常强大的特性,它允许将一个进程的输出直接当作一个文件来处理。这种技术通常用于Shell脚本和命令行操作中…

Java:日期时间范围的处理

java判断时间是否在某个时间段内_java判断一个时间是否在某个时间段-CSDN博客 java时间处理--判断当前时间是否在一个时间区间内_java_xtz......-腾讯云开发者社区 //需求:你发布了一个二手商品信息,其他用户看到后给你商品留言,如果留言时…

微信登录模块封装

文章目录 1.资质申请2.combinations-wx-login-starter1.目录结构2.pom.xml 引入okhttp依赖3.WxLoginProperties.java 属性配置4.WxLoginUtil.java 后端通过 code 获取 access_token的工具类5.WxLoginAutoConfiguration.java 自动配置类6.spring.factories 激活自动配置类 3.com…

3 Spark SQL

Spark SQL 1. 数据分析方式2. SparkSQL 前世今生3. Hive 和 SparkSQL4. 数据分类和 SparkSQL 适用场景1) 结构化数据2) 半结构化数据3) 总结 5. Spark SQL 数据抽象1) DataFrame2) DataSet3) RDD、DataFrame、DataSet 的区别4) 总结 6. Spark SQL 应用1) 创建 DataFrame/DataSe…

Node 处理客户端不同的请求方法

一、使用 http 模块处理请求方法 1. 创建 HTTP 服务器 const http require("http");const server http.createServer((req, res) > {// 处理不同的请求方法switch (req.method) {case "GET":handleGetRequest(req, res);break;case "POST"…