使用PyTorch实现逻辑回归:从训练到模型保存与加载

news/2025/2/1 23:08:18/

1. 引入必要的库

首先,需要引入必要的库。PyTorch用于构建和训练模型,pandas和numpy用于数据处理,matplotlib用于结果的可视化。

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt


2. 加载自定义数据集

有一个CSV文件custom_dataset.csv,其中包含特征(自变量)和标签(因变量)。使用pandas来加载数据,并进行预处理。

# 加载自定义数据集
data = pd.read_csv('custom_dataset.csv')# 假设数据集中有多列特征和一个二分类标签
X = data.iloc[:, :-1].values.astype(np.float32)  # 特征
y = data.iloc[:, -1].values.astype(np.float32)   # 标签# 将标签转换为0和1
y = np.where(y == 'positive', 1, 0)


3. 创建数据集和数据加载器

使用PyTorch的TensorDatasetDataLoader来创建数据集和数据加载器。

# 创建数据集和数据加载器
dataset = TensorDataset(torch.tensor(X), torch.tensor(y))
train_loader = DataLoader(dataset, batch_size=32, shuffle=True)


4. 定义逻辑回归模型

使用PyTorch的nn.Module来定义逻辑回归模型。

class LogisticRegression(nn.Module):def __init__(self, input_dim):super(LogisticRegression, self).__init__()self.linear = nn.Linear(input_dim, 1)def forward(self, x):outputs = torch.sigmoid(self.linear(x))return outputs# 初始化模型
input_dim = X.shape[1]
model = LogisticRegression(input_dim)

5. 训练模型

定义损失函数和优化器,然后训练模型。

# 定义损失函数和优化器
criterion = nn.BCELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)# 训练模型
num_epochs = 100
for epoch in range(num_epochs):for inputs, labels in train_loader:# 前向传播outputs = model(inputs)loss = criterion(outputs.flatten(), labels)# 反向传播和优化optimizer.zero_grad()loss.backward()optimizer.step()if (epoch+1) % 10 == 0:print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')

6. 保存模型

训练完成后,可以使用PyTorch的torch.save函数来保存模型。

# 保存模型
torch.save(model.state_dict(), 'logistic_regression_model.pth')


7. 加载模型并进行预测

在需要时,可以使用torch.load函数加载模型,并进行预测。

# 加载模型
model = LogisticRegression(input_dim)
model.load_state_dict(torch.load('logistic_regression_model.pth'))
model.eval()# 进行预测
with torch.no_grad():sample_inputs = torch.tensor(X[:5]).float()  # 示例输入predictions = model(sample_inputs)predicted_labels = (predictions.flatten() > 0.5).int()print("Predicted Labels:", predicted_labels.numpy())


http://www.ppmy.cn/news/1568536.html

相关文章

DeepSeek R1与OpenAI o1深度对比

文章目录 引言技术原理DeepSeek R1OpenAI o1 性能表现官方数据推理任务知识密集型任务通用能力 价格对比应用场景科研与技术开发自然语言处理(NLP)企业智能化升级教育与培训数据分析与智能决策 部署与集成DeepSeek R1OpenAI o1 伦理考量DeepSeek R1OpenA…

实验十 数据库完整性实验

实验十 数据库完整性实验 一、实验目的 1、熟悉通过SQL对数据进行完整性控制。熟练掌握数据库三类完整性约束(实体完整性、用户自定义完整性、参照完整性) 2、了解SQL SERVER 的违反完整性处理措施。 3、了解主键(PRIMARY KEY)约…

Matrials studio 软件安装步骤(百度网盘链接)

软件简介: Materials Studio是一款材料模拟软件。帮助建立三维结构模型,并对各种晶体、无定型以及高分子材料的性质及相关过程进行深入的研究。 网盘链接: https://pan.baidu.com/s/1h2yuuH6RQixpuWveJP4KDA?pwd22o9 提取码:22o9 安装…

无人机微波图像传输数据链技术详解

无人机微波图像传输数据链技术是无人机通信系统中的关键组成部分,它确保了无人机与地面站之间高效、可靠的图像数据传输。以下是对该技术的详细解析: 一、技术原理 无人机微波图像传输数据链主要基于微波通信技术实现。在数据链路中,图像数…

系统架构设计基础:概念与原则

系统架构设计基础:概念与原则 引言 系统架构设计是软件开发过程中至关重要的一环,它决定了系统的整体结构、组件之间的关系以及系统的可扩展性、可维护性和性能。系统架构设计师不仅需要具备扎实的技术功底,还需要对业务需求有深刻的理解,能够在复杂的需求中找到平衡点,…

MinDoc 安装与部署

下载可执行文件 mindoc mindoc_linux_amd64.zip 上传并解压压缩包 cd /opt mkdir mindoc cd mindocunzip mindoc_linux_amd64.zip 创建数据库 CREATE DATABASE mindoc_db DEFAULT CHARSET utf8mb4 COLLATE utf8mb4_general_ci; 配置数据库 将解压目录下 conf/app.conf.exam…

分布式微服务系统架构第87集:kafka

Kafka 就是为了解决上述问题而设计的一款基于发布与订阅的消息系统。它一般被称为 “分布式提交日志”或者“分布式流平台”。文件系统或数据库提交日志用来提供所有事务 的持久记录,通过重放这些日志可以重建系统的状态。同样地,Kafka 的数据是按照一定…

安卓通过网络获取位置的方法

一 方法介绍 1. 基本权限设置 首先需要在 AndroidManifest.xml 中添加必要权限&#xff1a; xml <uses-permission android:name"android.permission.INTERNET" /> <uses-permission android:name"android.permission.ACCESS_NETWORK_STATE" /&g…