使用PyTorch实现逻辑回归:从训练到模型保存与加载

ops/2025/2/2 2:04:00/

1. 引入必要的库

首先,需要引入必要的库。PyTorch用于构建和训练模型,pandas和numpy用于数据处理,matplotlib用于结果的可视化。

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt


2. 加载自定义数据集

有一个CSV文件custom_dataset.csv,其中包含特征(自变量)和标签(因变量)。使用pandas来加载数据,并进行预处理。

# 加载自定义数据集
data = pd.read_csv('custom_dataset.csv')# 假设数据集中有多列特征和一个二分类标签
X = data.iloc[:, :-1].values.astype(np.float32)  # 特征
y = data.iloc[:, -1].values.astype(np.float32)   # 标签# 将标签转换为0和1
y = np.where(y == 'positive', 1, 0)


3. 创建数据集和数据加载器

使用PyTorch的TensorDatasetDataLoader来创建数据集和数据加载器。

# 创建数据集和数据加载器
dataset = TensorDataset(torch.tensor(X), torch.tensor(y))
train_loader = DataLoader(dataset, batch_size=32, shuffle=True)


4. 定义逻辑回归模型

使用PyTorch的nn.Module来定义逻辑回归模型。

class LogisticRegression(nn.Module):def __init__(self, input_dim):super(LogisticRegression, self).__init__()self.linear = nn.Linear(input_dim, 1)def forward(self, x):outputs = torch.sigmoid(self.linear(x))return outputs# 初始化模型
input_dim = X.shape[1]
model = LogisticRegression(input_dim)

5. 训练模型

定义损失函数和优化器,然后训练模型。

# 定义损失函数和优化器
criterion = nn.BCELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)# 训练模型
num_epochs = 100
for epoch in range(num_epochs):for inputs, labels in train_loader:# 前向传播outputs = model(inputs)loss = criterion(outputs.flatten(), labels)# 反向传播和优化optimizer.zero_grad()loss.backward()optimizer.step()if (epoch+1) % 10 == 0:print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')

6. 保存模型

训练完成后,可以使用PyTorch的torch.save函数来保存模型。

# 保存模型
torch.save(model.state_dict(), 'logistic_regression_model.pth')


7. 加载模型并进行预测

在需要时,可以使用torch.load函数加载模型,并进行预测。

# 加载模型
model = LogisticRegression(input_dim)
model.load_state_dict(torch.load('logistic_regression_model.pth'))
model.eval()# 进行预测
with torch.no_grad():sample_inputs = torch.tensor(X[:5]).float()  # 示例输入predictions = model(sample_inputs)predicted_labels = (predictions.flatten() > 0.5).int()print("Predicted Labels:", predicted_labels.numpy())


http://www.ppmy.cn/ops/154907.html

相关文章

CSS关系选择器详解

CSS关系选择器详解 学习前提什么是关系选择器?后代选择器(Descendant Combinator)语法示例注意事项 子代选择器(Child Combinator)语法示例注意事项 邻接兄弟选择器(Adjacent Sibling Combinator&#xff0…

双指针c++

双指针(Two Pointers)是一种常用的算法技巧,通常用于解决数组或链表中的问题,如滑动窗口、区间合并、有序数组的两数之和等。双指针的核心思想是通过两个指针的移动来优化时间复杂度,通常可以将 (O(n^2)) 的暴力解法优…

Titans 架构下MAC变体的探究

目前业界流行的 Transformer 模型架构虽然在大多数场景表现优秀,但其上下文窗口(Window)长度的限制,通常仅为几千到几万个 Token,这使得它们在处理长文本、多轮对话或需要大规模上下文记忆的任务中,往往无法…

flume和kafka整合 flume和kafka为什么一起用?

‌Flume和Kafka一起使用的主要原因是为了实现高效、可靠的数据采集和实时处理。‌‌12 实时流式日志处理的需求 Flume和Kafka结合使用的主要目的是为了完成实时流式的日志处理。Flume负责数据的采集和传输,而Kafka则作为消息缓存队列,能够有效地缓冲数据,防止数据堆积或丢…

Vue.js组件开发-实现导出PDF文件可自定义添加水印及水印样式方向

使用 Vue 实现导出 PDF 文件并添加水印,同时支持设置水印样式、方向和自定义水印内容。 步骤 安装依赖:使用 html2canvas 将 HTML 内容转换为 canvas,使用 jspdf 生成 PDF 文件。创建 Vue 组件:在组件中实现水印生成、HTML 转 c…

一文介绍Hive数据类型

一文介绍Hive数据类型 文章目录 一文介绍Hive数据类型写在前面基本数据类型集合数据类型介绍案例实操 类型转化隐式类型转换CAST操作 写在前面 Linux版本:CentOS7.5Hive版本:Hive-3.1.2 基本数据类型 如下表所示: Hive数据类型Java数据类型…

重构字符串(767)

767. 重构字符串 - 力扣&#xff08;LeetCode&#xff09; 解法&#xff1a; class Solution { public:string reorganizeString(string s){string res;//因为1 < s.length < 500 &#xff0c; uint64_t 类型足够uint16_t n s.size();if (n 0) {return res;}unordere…

“com.docker.vmnetd”将对你的电脑造成伤害。 如何解决 |Mac

电脑型号&#xff1a;Macbook pro &#xff08;Apple M3 Pro&#xff09; 系统版本&#xff1a;15.2 打开电脑突然提示“com.docker.vmnetd”将对你的电脑造成伤害&#xff0c;执行以下操作 # 停掉 Docker 服务 sudo pkill [dD]ocker# 停掉 vmnetd 服务 sudo launchctl boot…