七。自定义数据集 使用tensorflow框架实现逻辑回归并保存模型,然后保存模型后再加载模型进行预测

server/2025/2/9 14:26:34/

import tensorflow as tf
import numpy as np

# 自定义数据集类
class CustomDataset(tf.data.Dataset):
    def __init__(self, x_data, y_data):
        self.x_data = tf.convert_to_tensor(x_data, dtype=tf.float32)
        self.y_data = tf.convert_to_tensor(y_data, dtype=tf.float32)

    def __iter__(self):
        for i in range(len(self.x_data)):
            yield (self.x_data[i], self.y_data[i])

# 逻辑回归模型
class LogisticRegressionModel(tf.keras.Model):
    def __init__(self, input_dim):
        super(LogisticRegressionModel, self).__init__()
        self.linear = tf.keras.layers.Dense(1, input_shape=(input_dim,), activation='sigmoid')

    def call(self, x):
        return self.linear(x)

# 创建数据集
x_data = np.array([[1], [2], [3], [4], [5]], dtype=np.float32)
y_data = np.array([[0], [0], [1], [1], [1]], dtype=np.float32)
dataset = CustomDataset(x_data, y_data)

# 创建数据加载器
dataloader = dataset.batch(2).shuffle(100).repeat()

# 创建模型、损失函数和优化器
model = LogisticRegressionModel(input_dim=1)
loss_object = tf.keras.losses.BinaryCrossentropy()
optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)

# 训练模型
epochs = 100
for epoch in range(epochs):
    for x_batch, y_batch in dataloader:
        with tf.GradientTape() as tape:
            predictions = model(x_batch)
            loss = loss_object(y_batch, predictions)
        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    
    if (epoch+1) % 10 == 0:
        print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.numpy():.4f}')

# 保存模型
model.save('logistic_regression_model.h5')

# 加载模型
model = tf.keras.models.load_model('logistic_regression_model.h5')

# 进行预测
x_test = np.array([[6], [7], [8]], dtype=np.float32)
y_pred = model.predict(x_test)
print('预测值:', y_pred)
 


http://www.ppmy.cn/server/166240.html

相关文章

WPS中解除工作表密码保护(忘记密码)

1.下载vba插件 项目首页 - WPS中如何启用宏附wps.vba.exe下载说明分享:WPS中如何启用宏:附wps.vba.exe下载说明本文将详细介绍如何在WPS中启用宏功能,并提供wps.vba.exe文件的下载说明 - GitCode 并按照步骤安装 2.wps中点击搜索,输入开发…

数据结构:算法复杂度

前言 数据结构(Data Structure)是计算机存储、组织数据的方式,指相互之间存在一种或多种特定关系的数据元素的集合。没有一种单一的数据结构对所有用途都有用,所以我们要学各式各样的数据结构,如:线性表、树…

PyTorch Geometric(PyG)机器学习实战

PyTorch Geometric(PyG)机器学习实战 在图神经网络(GNN)的研究和应用中,PyTorch Geometric(PyG)作为一个基于PyTorch的库,提供了高效的图数据处理和模型构建功能。 本文将通过一个节…

在大型语言模型(LLM)框架内Transformer架构与混合专家(MoE)策略的概念整合

文章目录 传统的神经网络框架存在的问题一. Transformer架构综述1.1 transformer的输入1.1.1 词向量1.1.2 位置编码(Positional Encoding)1.1.3 编码器与解码器结构1.1.4 多头自注意力机制 二.Transformer分步详解2.1 传统词向量存在的问题2.2 详解编解码…

ES管理器焕新升级:紫色银狼主题来袭!

ES管理器(安卓版)迎来了一次令人眼前一亮的改头换面!此次更新最直观的变化集中在UI界面设计上。开发团队大胆突破,摒弃了以往稍显平庸的风格,引入了极具个性的全新主题——以热门游戏《崩坏:星穹铁道》中的…

Java实战经验分享

1. 项目优化与性能提升 面试问题: 聊聊你印象最深刻的项目,或者做了哪些优化 你在项目中如何解决缓存穿透问题? 缓存穿透是我们做缓存优化时最常遇到的问题,特别是当查询的对象在数据库中不存在时,缓存层和数据库都会…

800G光模块:引领未来数据中心与网络通信的新引擎

随着5G、云计算、人工智能和大数据技术的飞速发展,全球数据流量呈现爆发式增长。据预测,到2025年,全球数据总量将达到175ZB(泽字节),这对网络带宽和传输效率提出了前所未有的挑战。在这一背景下&#xff0c…

白嫖RTX 4090?Stable Diffusion:如何给线稿人物快速上色?

大家都知道,在设计的初期,我们通常会先绘制草图,然后再进行上色处理,最终才开始进行最终的设计工作。在这个上色的过程中,配色是至关重要的一环。这不仅方便了内部同事的评审,也让产品方和客户可以直观地了…