自定义数据集,使用 PyTorch 框架实现逻辑回归并保存模型,然后保存模型后再加载模型进行预测

news/2025/1/30 19:45:49/

在本文中,我们将展示如何使用 NumPy 创建自定义数据集,利用 PyTorch 实现一个简单的逻辑回归模型,并在训练完成后保存该模型,最后加载模型并用它进行预测。

1. 创建自定义数据集

首先,我们使用 NumPy 创建一个简单的二分类数据集。假设我们的数据集包含两个特征。

import numpy as np# 生成随机数据
np.random.seed(42)
X = np.random.randn(100, 2)  # 100个样本,2个特征
y = (X[:, 0] + X[:, 1] > 0).astype(int)  # 标签为1或0# 打印数据集的前5个样本
print(X[:5], y[:5])

2. 构建逻辑回归模型

接下来,我们使用 PyTorch 来构建一个简单的逻辑回归模型。PyTorch 提供了 torch.nn.Module 类,能够轻松实现神经网络模型。

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset# 转换 NumPy 数据为 PyTorch 张量
X_tensor = torch.tensor(X, dtype=torch.float32)
y_tensor = torch.tensor(y, dtype=torch.float32).view(-1, 1)# 创建数据集并加载
dataset = TensorDataset(X_tensor, y_tensor)
dataloader = DataLoader(dataset, batch_size=10, shuffle=True)# 定义逻辑回归模型
class LogisticRegressionModel(nn.Module):def __init__(self):super(LogisticRegressionModel, self).__init__()self.linear = nn.Linear(2, 1)  # 2个输入特征,1个输出def forward(self, x):return torch.sigmoid(self.linear(x))# 初始化模型
model = LogisticRegressionModel()

3. 训练模型

接下来,我们定义损失函数和优化器,并训练模型。

# 定义损失函数和优化器
criterion = nn.BCELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)# 训练模型
epochs = 1000
for epoch in range(epochs):for inputs, labels in dataloader:optimizer.zero_grad()  # 清空梯度outputs = model(inputs)  # 前向传播loss = criterion(outputs, labels)  # 计算损失loss.backward()  # 反向传播optimizer.step()  # 更新权重if (epoch + 1) % 100 == 0:print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}')

4. 保存模型

模型训练完成后,我们可以将模型保存到文件中。PyTorch 提供了 torch.save() 方法来保存模型的状态字典。

# 保存模型
torch.save(model.state_dict(), 'logistic_regression_model.pth')
print("模型已保存!")

5. 加载模型并进行预测

我们可以加载保存的模型,并对新数据进行预测。

# 加载模型
loaded_model = LogisticRegressionModel()
loaded_model.load_state_dict(torch.load('logistic_regression_model.pth'))
loaded_model.eval()  # 切换到评估模式# 进行预测
with torch.no_grad():test_data = torch.tensor([[1.5, -0.5]], dtype=torch.float32)prediction = loaded_model(test_data)print(f'预测值: {prediction.item():.4f}')

6. 总结

在这篇博客中,我们展示了如何使用 NumPy 创建一个简单的自定义数据集,并使用 PyTorch 实现一个逻辑回归模型。我们还展示了如何保存训练好的模型,并加载模型进行预测。通过保存和加载模型,我们可以在不同的时间或环境中重复使用已经训练好的模型,而不需要重新训练它。


http://www.ppmy.cn/news/1567960.html

相关文章

【C++动态规划 状态压缩】2597. 美丽子集的数目|2033

本文涉及知识点 C动态规划 LeetCode2597. 美丽子集的数目 给你一个由正整数组成的数组 nums 和一个 正 整数 k 。 如果 nums 的子集中,任意两个整数的绝对差均不等于 k ,则认为该子数组是一个 美丽 子集。 返回数组 nums 中 非空 且 美丽 的子集数目。…

http跳转https

1、第一种:不好使 在nginx的配置中,在https的server站点添加如下头部: add_header Strict-Transport-Security “max-age63072000; includeSubdomains; preload”; 这样当第一次以https方式访问我的网站,nginx则会告知客户端的浏览…

研发的立足之本到底是啥?

0 你的问题,我知道! 本文深入T型图“竖线”的立足之本:专业技术 技术赋能业务能力。研发在学习投入精力最多,也误区最多。 某粉丝感发展遇到瓶颈,项目都会做,但觉无提升,想跳槽。于是&#x…

react引入DingTalk-JinBuTi字体

要在React项目中引入DingTalk-JinBuTi字体,可以按照以下步骤操作: 1. 下载字体文件: 如果DingTalk-JinBuTi字体不是通过npm或yarn可直接安装的包,则需要从官方渠道下载字体文件。这通常包括.woff, .ttf, .eot, .svg等格式…

(开源)基于Django+Yolov8+Tensorflow的智能鸟类识别平台

1 项目简介(开源地址在文章结尾) 系统旨在为了帮助鸟类爱好者、学者、动物保护协会等群体更好的了解和保护鸟类动物。用户群体可以通过平台采集野外鸟类的保护动物照片和视频,甄别分类、实况分析鸟类保护动物,与全世界各地的用户&…

YOLOv11-ultralytics-8.3.67部分代码阅读笔记-head.py

head.py ultralytics\nn\modules\head.py 目录 head.py 1.所需的库和模块 2.class Detect(nn.Module): 3.class Segment(Detect): 4.class OBB(Detect): 5.class Pose(Detect): 6.class Classify(nn.Module): 7.class WorldDetect(Detect): 8.class RTDETRDec…

BGP分解实验·15——路由阻尼(抑制/衰减)实验

一个可以监控路由信息不稳定征兆的小特性,那些表现出不稳定的路由将会受到惩罚,直到它稳定下来为止。 实验拓扑如下: 配置两台路由器的基础连通性后,再到R2上设置半衰期5分钟、使用阈值750,惩罚阈值1500;并…

浅谈Redis

2007 年,一位程序员和朋友一起创建了一个网站。为了解决这个网站的负载问题,他自己定制了一个数据库。于2009 年开发,称之为Redis。这位意大利程序员是萨尔瓦托勒桑菲利波(Salvatore Sanfilippo),他被称为Redis之父,更…