使用numpy自定义数据集 使用tensorflow框架实现逻辑回归并保存模型,然后保存模型后再加载模型进行预

news/2025/2/6 0:18:34/

1. 引言

逻辑回归(Logistic Regression)是一种常见的分类算法,广泛应用于二分类问题。在本篇博客中,我们将使用numpy生成一个简单的自定义数据集,并使用TensorFlow框架构建和训练逻辑回归模型。训练完成后,我们会保存模型,并演示如何加载保存的模型进行预测。

2. 创建自定义数据集

首先,我们使用numpy生成一个简单的二分类数据集,包含两个特征和对应的标签。标签0表示负类,标签1表示正类。

import numpy as np# 设置随机种子,保证每次运行结果一致
np.random.seed(42)# 生成自定义数据集
X = np.random.rand(100, 2)  # 100个样本,2个特征
y = (X[:, 0] + X[:, 1] > 1).astype(int)  # 标签:如果两个特征之和大于1,标签为1,否则为0

这样我们就得到了一个简单的二分类数据集,X是特征矩阵,y是标签。

3. 构建逻辑回归模型

接下来,我们使用TensorFlow来构建逻辑回归模型。逻辑回归本质上是一个线性模型,通过Sigmoid函数输出概率,最终将其转化为二分类标签。

import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense# 构建模型
model = Sequential([Dense(1, input_dim=2, activation='sigmoid')  # 输入层2个特征,输出层1个节点,sigmoid激活函数
])# 编译模型
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

在这个模型中,我们只用了一个包含一个节点的输出层,使用sigmoid激活函数来输出分类概率。损失函数选择了binary_crossentropy,它是二分类问题常用的损失函数。

4. 训练模型

现在,我们使用生成的数据集来训练模型。

# 训练模型
model.fit(X, y, epochs=50, batch_size=10, verbose=1)

训练过程将进行50个周期,每批次包含10个样本。你可以根据自己的需求调整epochsbatch_size

5. 保存模型

训练完成后,我们将保存模型。TensorFlow提供了方便的保存方法,可以将整个模型(包括模型架构、权重和训练配置)保存在一个文件中。

# 保存模型
model.save('logistic_regression_model.h5')

保存后的模型文件logistic_regression_model.h5将包含模型的所有信息,稍后我们可以重新加载这个模型来进行预测。

6. 加载模型并进行预测

保存的模型可以在后续的工作中重新加载并使用。我们通过tensorflow.keras.models.load_model()来加载保存的模型。

# 加载模型
loaded_model = tf.keras.models.load_model('logistic_regression_model.h5')# 使用加载的模型进行预测
predictions = loaded_model.predict(X)# 输出预测结果(概率值)
print(predictions[:5])  # 打印前5个预测结果

这里我们通过加载的模型对输入数据X进行预测。由于是二分类问题,模型会输出一个概率值,我们可以根据这个概率值将其转换为标签(例如,概率大于0.5为正类)。

7. 总结

在本篇博客中,我们学习了如何使用numpy生成自定义数据集,使用TensorFlow框架构建并训练逻辑回归模型,保存模型并在之后加载模型进行预测。通过这种方式,你可以在训练完成后方便地保存和加载模型,从而实现模型的持久化,便于后续的应用和部署。

希望这篇博客对你理解逻辑回归TensorFlow的使用有所帮助!你可以在此基础上扩展应用到更复杂的模型和数据集。


http://www.ppmy.cn/news/1569647.html

相关文章

Elasticsearch Queries

Elasticsearch Compound Queries Elasticsearch 的 Compound Queries 是一种强大的工具,用于组合多个查询子句,以实现更复杂的搜索逻辑。这些查询子句可以是叶查询(Leaf Queries)或复合查询(Compound Queries&#xf…

P3199 【[HNOI2009]最小圈】

疑似三倍经验 因为和机房一些大佬一起做的这道题,所以emmm他们貌似也写了题解,在做这道题的时候也参照了其他大佬写的一些题解,所以如果程序有雷同请见谅(手动鞠躬) 题目也是莫名其妙地给了一大串数学式,简…

记录一次-Rancher通过UI-Create Custom- RKE2的BUG

一、下游集群 当你的下游集群使用Mysql外部数据库时,会报错: **他会检查ETCD。 但因为用的是Mysql外部数据库,这个就太奇怪了,而且这个检测不过,集群是咩办法被管理的。 二、如果不选择etcd,就选择控制面。 在rke2-…

Qt网络相关

“ 所有生而孤独的人,葆有的天真 ” 为了⽀持跨平台, QT对⽹络编程的 API 也进⾏了重新封装。本章会上手一套基于QT的网络通信编写。 UDP Socket 在使用Qt进行网络编程前,需要在Qt项目中的.pro文件里添加对应的网络模块( network ). QT core gui net…

民法学学习笔记(个人向) Part.2

民法学学习笔记(个人向) Part.2 民法始终在解决两个生活中的核心问题: 私法自治;交易安全; 3. 自然人 3.4 个体工商户、农村承包经营户 都是特殊的个体经济单位; 3.4.1 个体工商户 是指在法律的允许范围内,依法经…

自定义数据集 ,使用朴素贝叶斯对其进行分类

代码: # 导入必要的库 import numpy as np import matplotlib.pyplot as plt# 定义类1的数据点,每个数据点是二维的坐标 class1_points np.array([[1.9, 1.2],[1.5, 2.1],[1.9, 0.5],[1.5, 0.9],[0.9, 1.2],[1.1, 1.7],[1.4, 1.1]])# 定义类2的数据点&…

【JDBC】数据库连接的艺术:深入解析数据库连接池、Apache-DBUtils与BasicDAO

文章目录 前言🌍 一.连接池❄️1. 传统获取Conntion问题分析❄️2. 数据库连接池❄️3.连接池之C3P0技术🍁3.1关键特性🍁3.2配置选项🍁3.3使用示例 ❄️4. 连接池之Druid技术🍁 4.1主要特性🍁 4.2 配置选项…

课题推荐——基于自适应滤波技术的多传感器融合在无人机组合导航中的应用研究

无人机在现代航空、农业和监测等领域的应用日益广泛。为了提高导航精度,通常采用多传感器融合技术,将来自GPS、惯性测量单元(IMU)、磁力计等不同传感器的数据整合。然而,传感器的量测偏差、环境干扰以及非线性特性使得…