使用numpy自定义数据集 使用tensorflow框架实现逻辑回归并保存模型,然后保存模型后再加载模型进行预

server/2025/2/1 23:37:43/

1. 引言

逻辑回归(Logistic Regression)是一种常见的分类算法,广泛应用于二分类问题。在本篇博客中,我们将使用numpy生成一个简单的自定义数据集,并使用TensorFlow框架构建和训练逻辑回归模型。训练完成后,我们会保存模型,并演示如何加载保存的模型进行预测。

2. 创建自定义数据集

首先,我们使用numpy生成一个简单的二分类数据集,包含两个特征和对应的标签。标签0表示负类,标签1表示正类。

import numpy as np# 设置随机种子,保证每次运行结果一致
np.random.seed(42)# 生成自定义数据集
X = np.random.rand(100, 2)  # 100个样本,2个特征
y = (X[:, 0] + X[:, 1] > 1).astype(int)  # 标签:如果两个特征之和大于1,标签为1,否则为0

这样我们就得到了一个简单的二分类数据集,X是特征矩阵,y是标签。

3. 构建逻辑回归模型

接下来,我们使用TensorFlow来构建逻辑回归模型。逻辑回归本质上是一个线性模型,通过Sigmoid函数输出概率,最终将其转化为二分类标签。

import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense# 构建模型
model = Sequential([Dense(1, input_dim=2, activation='sigmoid')  # 输入层2个特征,输出层1个节点,sigmoid激活函数
])# 编译模型
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

在这个模型中,我们只用了一个包含一个节点的输出层,使用sigmoid激活函数来输出分类概率。损失函数选择了binary_crossentropy,它是二分类问题常用的损失函数。

4. 训练模型

现在,我们使用生成的数据集来训练模型。

# 训练模型
model.fit(X, y, epochs=50, batch_size=10, verbose=1)

训练过程将进行50个周期,每批次包含10个样本。你可以根据自己的需求调整epochsbatch_size

5. 保存模型

训练完成后,我们将保存模型。TensorFlow提供了方便的保存方法,可以将整个模型(包括模型架构、权重和训练配置)保存在一个文件中。

# 保存模型
model.save('logistic_regression_model.h5')

保存后的模型文件logistic_regression_model.h5将包含模型的所有信息,稍后我们可以重新加载这个模型来进行预测。

6. 加载模型并进行预测

保存的模型可以在后续的工作中重新加载并使用。我们通过tensorflow.keras.models.load_model()来加载保存的模型。

# 加载模型
loaded_model = tf.keras.models.load_model('logistic_regression_model.h5')# 使用加载的模型进行预测
predictions = loaded_model.predict(X)# 输出预测结果(概率值)
print(predictions[:5])  # 打印前5个预测结果

这里我们通过加载的模型对输入数据X进行预测。由于是二分类问题,模型会输出一个概率值,我们可以根据这个概率值将其转换为标签(例如,概率大于0.5为正类)。

7. 总结

在本篇博客中,我们学习了如何使用numpy生成自定义数据集,使用TensorFlow框架构建并训练逻辑回归模型,保存模型并在之后加载模型进行预测。通过这种方式,你可以在训练完成后方便地保存和加载模型,从而实现模型的持久化,便于后续的应用和部署。

希望这篇博客对你理解逻辑回归TensorFlow的使用有所帮助!你可以在此基础上扩展应用到更复杂的模型和数据集。


http://www.ppmy.cn/server/164193.html

相关文章

磁感应编码器实现原理和C语言实现

目录 概述 1 核心物理原理 2 硬件结构设计 2.1 磁栅组件 2.2 传感器阵列 3 信号处理流程 4 关键技术突破 5 典型应用对比 6 实际应用案例 7 C语言的算法实现 7.1 核心实现原理 7.1.1 磁场空间分布建模 7.1.2 正交信号生成 7.2 完整C语言实现代码 7.3 应用层实现…

LangChain教程 - RAG - PDF解析

在现代人工智能和自然语言处理(NLP)应用中,处理PDF文档是一项常见且重要的任务。由于PDF格式的复杂性,包含文本、图像、表格等多种内容结构,高效、准确地解析PDF需要强大的工具支持。LangChain提供了一套完善的PDF加载…

leetcode 2300. 咒语和药水的成功对数

题目如下 数据范围 示例 注意到n和m的长度最长达到10的5次方所以时间复杂度为n方的必然超时。 因为题目要求我们返回每个位置的spell对应的有效对数所以我们只需要找到第一个有效的药水就行,这里可以先对potions排序随后使用二分查找把时间复杂度压到nlogn就不会…

C++学习第五天

创作过程中难免有不足,若您发现本文内容有误,恳请不吝赐教。 提示:以下是本篇文章正文内容,下面案例可供参考 一、构造函数 问题1 关于编译器生成的默认成员函数,很多童鞋会有疑惑:不实现构造函数的情况下…

C语言小项目——通讯录

功能介绍: 1.联系人信息:姓名年龄性别地址电话 2.通讯录中可以存放100个人的信息 3.功能: 1>增加联系人 2>删除指定联系人 3>查找指定联系人的信息 4>修改指定联系人的信息 5显示所有联系人的信息 6>排序(名字&…

【愚公系列】《循序渐进Vue.js 3.x前端开发实践》038-使用CSS3创建动画(keyframes动画)

标题详情作者简介愚公搬代码头衔华为云特约编辑,华为云云享专家,华为开发者专家,华为产品云测专家,CSDN博客专家,CSDN商业化专家,阿里云专家博主,阿里云签约作者,腾讯云优秀博主&…

Python 梯度下降法(五):Adam Optimize

文章目录 Python 梯度下降法(五):Adam Optimize一、数学原理1.1 介绍1.2 符号说明1.3 实现流程 二、代码实现2.1 函数代码2.2 总代码2.3 遇到的问题2.4 算法优化 三、优缺点3.1 优点3.2 缺点 四、相关链接 Python 梯度下降法(五&a…

【Android】问deepseek存储访问

这些天deepseek爆火,我们来问问android问题看看,如果问android中的应用怎么访问外部存储,回答的很清楚,但是如果问的深入一些,比如Android中是怎么控制让应用不能读取其他应用的外部存储文件的,回答的比较抽…