使用TensorFlow实现逻辑回归:从训练到模型保存与加载

news/2025/2/4 0:01:06/

1. 引入必要的库

首先,需要引入必要的库。TensorFlow用于构建和训练模型,pandas和numpy用于数据处理,matplotlib用于结果的可视化。

import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.optimizers import SGD
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt


2. 加载自定义数据集

假设有一个CSV文件custom_dataset.csv,其中包含特征(自变量)和标签(因变量)。使用pandas来加载数据,并进行预处理。

# 加载自定义数据集
data = pd.read_csv('custom_dataset.csv')# 假设数据集中有多列特征和一个二分类标签
X = data.iloc[:, :-1].values.astype(np.float32)  # 特征
y = data.iloc[:, -1].values.astype(np.float32)   # 标签# 将标签转换为0和1
y = np.where(y == 'positive', 1, 0)

3. 构建逻辑回归模型

使用TensorFlow的Keras接口来构建逻辑回归模型。

# 构建逻辑回归模型
model = Sequential([Dense(1, activation='sigmoid', input_shape=(X.shape[1],))
])# 编译模型
model.compile(optimizer=SGD(learning_rate=0.01), loss='binary_crossentropy', metrics=['accuracy'])


4. 训练模型

使用自定义数据集训练模型。

# 训练模型
history = model.fit(X, y, epochs=100, batch_size=32, verbose=1)


5. 保存模型

训练完成后,可以使用TensorFlow的save方法保存模型。

# 保存模型
model.save('logistic_regression_model.h5')


6. 加载模型并进行预测

在需要时,可以使用TensorFlow的load_model方法加载模型,并进行预测。

# 加载模型
from tensorflow.keras.models import load_modelloaded_model = load_model('logistic_regression_model.h5')# 进行预测
predictions = loaded_model.predict(X[:5])
predicted_labels = (predictions > 0.5).astype(int)print("Predicted Labels:", predicted_labels.flatten())


7. 结果可视化

可以绘制训练过程中的损失和准确率变化曲线,以帮助理解模型的性能。

# 绘制训练和验证的损失曲线
plt.plot(history.history['loss'], label='Loss')
plt.title('Model Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()# 绘制训练和验证的准确率曲线
plt.plot(history.history['accuracy'], label='Accuracy')
plt.title('Model Accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()
plt.show()


http://www.ppmy.cn/news/1569092.html

相关文章

浅谈 JSON 对象和 FormData 相互转换,打通前端与后端的通信血脉_json转formdata

formData 请求头: formData 负荷: 通过上面的几张图我们就能大概明白了,前端传的都是二进制数据,两者的 content-type 是不同的,json 我们已经序列化好了,而 formdata 还是需要进行处理。 formdata 的两种格…

ubuntu无法上网的解决办法

Ubuntu系统无法联网可能有多种原因,以下是一些常见的排查步骤和解决方法: 1. 检查网络连接状态 首先,确认网络接口是否已启用。 ip a查看网络接口(如eth0、wlan0)是否有IP地址。如果没有,可能是接口未启…

基于Python的药物相互作用预测模型AI构建与优化(上.文字部分)

一、引言 1.1 研究背景与意义 在临床用药过程中,药物相互作用(Drug - Drug Interaction, DDI)是一个不可忽视的重要问题。当患者同时服用两种或两种以上药物时,药物之间可能会发生相互作用,从而改变药物的疗效、增加不良反应的发生风险,甚至危及患者的生命安全。例如,…

Rust语言的编程范式

Rust语言的编程范式 引言 在现代编程语言中,Rust以其独特的编程范式和内存管理机制脱颖而出。Rust不仅关注性能与安全性,还通过其独特的语法和强大的工具链,引导开发者采取更好的编程实践。本文将深入探讨Rust语言的编程范式,包…

【Numpy核心编程攻略:Python数据处理、分析详解与科学计算】2.12 连续数组:为什么contiguous这么重要?

2.12 连续数组:为什么contiguous这么重要? 目录 #mermaid-svg-wxhozKbHdFIldAkj {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-wxhozKbHdFIldAkj .error-icon{fill:#552222;}#mermaid-svg-…

【C++动态规划 离散化】1626. 无矛盾的最佳球队|2027

本文涉及知识点 C动态规划 离散化 LeetCode1626. 无矛盾的最佳球队 假设你是球队的经理。对于即将到来的锦标赛,你想组合一支总体得分最高的球队。球队的得分是球队中所有球员的分数 总和 。 然而,球队中的矛盾会限制球员的发挥,所以必须选…

MySQL基本架构SQL语句在数据库框架中的执行流程数据库的三范式

MySQL基本架构图: MySQL主要分为Server层和存储引擎层 Server层: 连接器:连接客户端,获取权限,管理连接 查询缓存(可选):在执行查询语句之前会先到查询缓存中查看是否执行过这条语…

SQL UCASE() 函数详解

SQL UCASE() 函数详解 在SQL中,UCASE() 函数是一个非常有用的字符串处理函数,它可以将字符串中的所有小写字母转换为大写字母。本文将详细介绍UCASE() 函数的用法、语法、示例以及其在实际应用中的优势。 一、UCASE() 函数简介 UCASE() 函数是SQL标准…