使用 PyTorch 实现逻辑回归并评估模型性能

ops/2025/2/4 4:47:08/

1. 逻辑回归简介

逻辑回归是一种用于解决二分类问题的算法。它通过一个逻辑函数(Sigmoid 函数)将线性回归的输出映射到 [0, 1] 区间内,从而将问题转化为概率预测问题。如果预测概率大于 0.5,则将样本分类为正类;否则分类为负类。

2. 数据准备

为了演示逻辑回归的效果,我们构造了一个简单的二维数据集,包含两类样本。每类样本有 7 个数据点,特征维度为 2。

class1_points = np.array([[1.9, 1.2],[1.5, 2.1],[1.9, 0.5],[1.5, 0.9],[0.9, 1.2],[1.1, 1.7],[1.4, 1.1]])class2_points = np.array([[3.2, 3.2],[3.7, 2.9],[3.2, 2.6],[1.7, 3.3],[3.4, 2.6],[4.1, 2.3],[3.0, 2.9]])

我们将这两类数据点的特征合并,并为每个数据点分配标签(0 表示第一类,1 表示第二类)。

3. 模型构建

我们使用 PyTorch 框架来实现逻辑回归模型。模型结构非常简单,仅包含一个线性层和一个 Sigmoid 激活函数。

class LogisticRegression(nn.Module):def __init__(self):super(LogisticRegression, self).__init__()self.linear = nn.Linear(2, 1)  # 输入特征维度为 2,输出为 1def forward(self, x):return torch.sigmoid(self.linear(x))

4. 模型训练

我们使用二分类交叉熵损失函数(BCELoss)和随机梯度下降优化器(SGD)来训练模型。训练过程如下:

epochs = 5000
for epoch in range(epochs):model.train()optimizer.zero_grad()outputs = model(X)loss = criterion(outputs, y)loss.backward()optimizer.step()if (epoch + 1) % 100 == 0:print(f'Epoch [{epoch + 1}/{epochs}], Loss: {loss.item():.4f}')

训练过程中,我们每 100 个 epoch 打印一次损失值,以便观察模型的收敛情况。

5. 模型保存与加载

训练完成后,我们将模型的权重保存到文件中,方便后续加载和使用。

torch.save(model.state_dict(), 'model3.pth')
print("模型已保存")

加载模型时,我们创建一个新的模型实例,并使用 load_state_dict 方法加载保存的权重。

loaded_model = LogisticRegression()
loaded_model.load_state_dict(torch.load('model3.pth', map_location=torch.device('cpu')))
loaded_model.eval()

6. 模型预测与性能评估

加载模型后,我们使用模型对训练数据进行预测,并计算精确度、召回率和 F1 分数。

with torch.no_grad():predictions = loaded_model(X)predicted_labels = (predictions > 0.5).float()print("实际结果:", y.numpy().flatten())
print("预测结果:", predicted_labels.numpy().flatten())precision = precision_score(y.numpy().flatten(), predicted_labels.numpy().flatten())
recall = recall_score(y.numpy().flatten(), predicted_labels.numpy().flatten())
f1 = f1_score(y.numpy().flatten(), predicted_labels.numpy().flatten())print(f"精确度(Precision): {precision:.4f}")
print(f"召回率(Recall): {recall:.4f}")
print(f"F1 分数: {f1:.4f}")

7. 运行结果

运行上述代码后,我们得到了以下结果:

  • 实际结果:[0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 1. 1. 1. 1.]

  • 预测结果:[0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 1. 1. 1. 1.]

  • 精确度(Precision):1.0000

  • 召回率(Recall):1.0000

  • F1 分数:1.0000

从结果可以看出,模型在训练集上表现良好,精确度、召回率和 F1 分数均为 1.0000。

8. 完整代码

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import precision_score, recall_score, f1_score"""使用pytorch框架实现逻辑回归并保存模型,然后保存模型后再加载模型进行预测,对预测结果计算精确度和召回率及F1分数"""
# 提取特征和标签
class1_points = np.array([[1.9, 1.2],[1.5, 2.1],[1.9, 0.5],[1.5, 0.9],[0.9, 1.2],[1.1, 1.7],[1.4, 1.1]])class2_points = np.array([[3.2, 3.2],[3.7, 2.9],[3.2, 2.6],[1.7, 3.3],[3.4, 2.6],[4.1, 2.3],[3.0, 2.9]])# 提取两类特征,输入特征维度为2
x1_data = np.concatenate((class1_points[:, 0], class2_points[:, 0]), axis=0)
x2_data = np.concatenate((class1_points[:, 1], class2_points[:, 1]), axis=0)
label = np.concatenate((np.zeros(len(class1_points)), np.ones(len(class2_points))), axis=0)# 将数据转换为 PyTorch 张量
X = torch.tensor(np.column_stack((x1_data, x2_data)), dtype=torch.float32)
y = torch.tensor(label, dtype=torch.float32).view(-1, 1)# 定义逻辑回归模型
class LogisticRegression(nn.Module):def __init__(self):super(LogisticRegression, self).__init__()self.linear = nn.Linear(2, 1)  # 输入特征维度为 2,输出为 1def forward(self, x):return torch.sigmoid(self.linear(x))# 初始化模型、损失函数和优化器
model = LogisticRegression()
criterion = nn.BCELoss()  # 二分类交叉熵损失
optimizer = optim.SGD(model.parameters(), lr=0.01)# 训练模型
epochs = 5000
for epoch in range(epochs):model.train()optimizer.zero_grad()outputs = model(X)loss = criterion(outputs, y)loss.backward()optimizer.step()if (epoch + 1) % 100 == 0:print(f'Epoch [{epoch + 1}/{epochs}], Loss: {loss.item():.4f}')# 保存模型
torch.save(model.state_dict(), 'model3.pth')
print("模型已保存")# 加载模型
loaded_model = LogisticRegression()
loaded_model.load_state_dict(torch.load('model3.pth',map_location=torch.device('cpu'),weights_only=True))
loaded_model.eval()# 进行预测
with torch.no_grad():predictions = loaded_model(X)predicted_labels = (predictions > 0.5).float()# 展示预测结果和实际结果
print("实际结果:", y.numpy().flatten())
print("预测结果:", predicted_labels.numpy().flatten())# 计算精确度、召回率和 F1 分数
precision = precision_score(y.numpy().flatten(), predicted_labels.numpy().flatten())
recall = recall_score(y.numpy().flatten(), predicted_labels.numpy().flatten())
f1 = f1_score(y.numpy().flatten(), predicted_labels.numpy().flatten())print(f"精确度(Precision): {precision:.4f}")
print(f"召回率(Recall): {recall:.4f}")
print(f"F1 分数: {f1:.4f}")


http://www.ppmy.cn/ops/155477.html

相关文章

Flutter 与 React 前端框架对比:深入分析与实战示例

Flutter 与 React 前端框架对比:深入分析与实战示例 在现代前端开发中,Flutter 和 React 是两个非常流行的框架。Flutter 是 Google 推出的跨平台开发框架,支持从一个代码库生成 iOS、Android、Web 和桌面应用;React 则是 Facebo…

DeepSeek大模型技术解析:从架构到应用的全面探索

一、引言 在人工智能领域,大模型的发展日新月异,其中DeepSeek大模型凭借其卓越的性能和广泛的应用场景,迅速成为业界的焦点。本文旨在深入剖析DeepSeek大模型的技术细节,从架构到应用进行全面探索,以期为读者提供一个…

低代码系统-产品架构案例介绍、炎黄盈动-易鲸云(十二)

易鲸云作为炎黄盈动新推出的产品,在定位上为低零代码产品。 开发层 表单引擎 表单设计器,包括设计和渲染 流程引擎 流程设计,包括设计和渲染,需要说明的是:采用国际标准BPMN2.0,可以全球通用 视图引擎 视图…

openmv的端口被拆分为两个 导致电脑无法访问openmv文件系统解决办法 openmv USB功能改动 openmv驱动被更改如何修复

我之前误打误撞遇到一次,直接把openmv的全部端口删除卸载然后重新插上就会自动重新装上一个openmv端口修复成功,大家可以先试试不行再用下面的方法 全部卸载再重新插拔openmv 要解决OpenMV IDE中出现的两个端口问题,可以尝试以下步骤&#x…

【2025年最新版】Java JDK安装、环境配置教程 (图文非常详细)

文章目录 【2025年最新版】Java JDK安装、环境配置教程 (图文非常详细)1. JDK介绍2. 下载 JDK3. 安装 JDK4. 配置环境变量5. 验证安装6. 创建并测试简单的 Java 程序6.1 创建 Java 程序:6.2 编译和运行程序:6.3 在显示或更改文件的…

Java实现.env文件读取敏感数据

文章目录 1.common-env-starter模块1.目录结构2.DotenvEnvironmentPostProcessor.java 在${xxx}解析之前执行,提前读取配置3.EnvProperties.java 这里的path只是为了代码提示4.EnvAutoConfiguration.java Env模块自动配置类5.spring.factories 自动配置和注册Enviro…

rust跨平台调用动态库

动态库在不同的操作系统&#xff0c;扩展名是不一样的&#xff0c;所以要做处理: static LIB: Lazy<Mutex<Option<Library>>> Lazy::new(|| Mutex::new(None));type CreateFunc unsafe extern "C" fn(*const c_char, *const c_char) -> c_int…

亚博microros小车-原生ubuntu支持系列:19 nav2 导航

开始小车测试之前&#xff0c;先补充下背景知识 nav2 Navigation2具有下列工具&#xff1a; 加载、提供和存储地图的工具&#xff08;地图服务器Map Server&#xff09; 在地图上定位机器人的工具 (AMCL) 避开障碍物从A点移动到B点的路径规划工具&#xff08;Nav2 Planner&a…