七。自定义数据集 使用tensorflow框架实现逻辑回归并保存模型,然后保存模型后再加载模型进行预测

devtools/2025/2/8 9:13:37/

import tensorflow as tf
import numpy as np

# 自定义数据集类
class CustomDataset(tf.data.Dataset):
    def __init__(self, x_data, y_data):
        self.x_data = tf.convert_to_tensor(x_data, dtype=tf.float32)
        self.y_data = tf.convert_to_tensor(y_data, dtype=tf.float32)

    def __iter__(self):
        for i in range(len(self.x_data)):
            yield (self.x_data[i], self.y_data[i])

# 逻辑回归模型
class LogisticRegressionModel(tf.keras.Model):
    def __init__(self, input_dim):
        super(LogisticRegressionModel, self).__init__()
        self.linear = tf.keras.layers.Dense(1, input_shape=(input_dim,), activation='sigmoid')

    def call(self, x):
        return self.linear(x)

# 创建数据集
x_data = np.array([[1], [2], [3], [4], [5]], dtype=np.float32)
y_data = np.array([[0], [0], [1], [1], [1]], dtype=np.float32)
dataset = CustomDataset(x_data, y_data)

# 创建数据加载器
dataloader = dataset.batch(2).shuffle(100).repeat()

# 创建模型、损失函数和优化器
model = LogisticRegressionModel(input_dim=1)
loss_object = tf.keras.losses.BinaryCrossentropy()
optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)

# 训练模型
epochs = 100
for epoch in range(epochs):
    for x_batch, y_batch in dataloader:
        with tf.GradientTape() as tape:
            predictions = model(x_batch)
            loss = loss_object(y_batch, predictions)
        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    
    if (epoch+1) % 10 == 0:
        print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.numpy():.4f}')

# 保存模型
model.save('logistic_regression_model.h5')

# 加载模型
model = tf.keras.models.load_model('logistic_regression_model.h5')

# 进行预测
x_test = np.array([[6], [7], [8]], dtype=np.float32)
y_pred = model.predict(x_test)
print('预测值:', y_pred)
 


http://www.ppmy.cn/devtools/157062.html

相关文章

图论——环检测

环检测以及拓扑排序 前言复习模版环检测-DFS版本环检测- BFS版本 前言 我觉得学习这些之前,一定要对图的数据结构和抽象模型有概念,并且图构建的代码模版应该手到擒来,不然还是挺折磨的,不是这差一点就是那差一点,写道力扣卡卡的非常烦人. 复习模版 我觉得单拿出来再说这个模…

Spring Boot整合MQTT

MQTT是基于代理的轻量级的消息发布订阅传输协议。 1、下载安装代理 进入mosquitto下载地址:Download | Eclipse Mosquitto,进行下载,以win版本为例 下载完成后,在本地文件夹找到下载的代理安装文件 使用管理员身份打开安装 安装…

【第一章】 操作系统的概述

目录 零、前言 0.1 考纲内容 0.2 考情统计 0.3 考点解读 0.4 复习建议 一、操作系统的基本概念 1.1 操作系统的概念 1.1.1 电脑的诞生过程 1.1.2 操作系统的定义 1.2 操作系统的功能 1.2.1 QQ聊天引入 1.2.2 处理器管理的功能 1.2.3 存储器管理的功能 1.2.4 文件…

Linux——基础命令2

1、用户 Linux是一个多用户多任务操作系统,任何一个要使用系统资源的用户,都必须首先向系统管理员申请一个账号,然后以这个账号的身份进入系统。 Linux系统支持多个用户在同一时间内登陆,不同用户可以执行不同的任务&#xff0c…

群晖NAS如何通过WebDAV和内网穿透实现Joplin笔记远程同步

文章目录 前言1. 检查群晖Webdav 服务2. 本地局域网IP同步测试3. 群晖安装Cpolar工具4. 创建Webdav公网地址5. Joplin连接WebDav6. 固定Webdav公网地址7. 公网环境连接测试 前言 在数字化浪潮的推动下,笔记应用已成为我们记录生活、整理思绪的重要工具。Joplin&…

(苍穹外卖)项目结构

苍穹外卖项目结构 后端工程基于 maven 进行项目构建,并且进行分模块开发。 1). 用 IDEA 打开初始工程,了解项目的整体结构: 对工程的每个模块作用说明: 序号名称说明1sky-take-outmaven父工程,统一管理依赖版本&…

​PDFsam Basic是一款 免费开源的PDF分割合并工具

PDFsam Basic 是一款功能强大的 PDF 工具,专为满足用户对 PDF 文件的各种操作需求而设计。它能够高效地拆分、合并、提取页面、混合以及旋转 PDF 文件,为用户提供灵活的文档处理解决方案。 合并 PDF 文件 PDF 合并是 PDFsam Basic 最受欢迎的功能之一。…

C++中的based for 循环

文章目录 范围基 for 循环(Range-based for Loop)语法格式例子1. 遍历数组2. 遍历 std::vector3. 使用引用避免拷贝4. 使用常量引用 特殊用法5. 遍历 std::map 或 std::unordered_map 总结 在 C 中,based for 循环并不是一种标准的语法&#…