自定义数据集 使用pytorch框架实现逻辑回归并保存模型,然后保存模型后再加载模型进行预测,对预测结果计算精确度和召回率及F1分数

embedded/2025/2/5 17:53:02/

代码:

import torch
import numpy as np
import torch.nn as nn
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score# 定义数据:x_data 是特征,y_data 是标签(目标值)
data = [[-0.5, 7.7],[1.8, 98.5],[0.9, 57.8],[0.4, 39.2],[-1.4, -15.7],[-1.4, -37.3],[-1.8, -49.1],[1.5, 75.6],[0.4, 34.0],[0.8, 62.3]]# 将数据转为 numpy 数组
data = np.array(data)# 提取 x_data 和 y_data
x_data = data[:, 0]  # 取第一列作为输入特征
y_data = data[:, 1]  # 取第二列作为目标标签# 将数据转换为 PyTorch 张量
x_train = torch.tensor(x_data, dtype=torch.float32)  # 输入特征
y_train = torch.tensor(y_data, dtype=torch.float32)  # 目标标签# 使用 TensorDataset 来创建一个数据集
from torch.utils.data import DataLoader, TensorDatasetdataset = TensorDataset(x_train, y_train)  # 使用训练数据创建数据集
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)  # 将数据集转换为 DataLoader,批大小为 2,且每个 epoch 都会随机打乱数据# 定义损失函数:均方误差损失 (MSELoss)
criterion = nn.MSELoss()# 定义线性回归模型
class LinearModel(nn.Module):def __init__(self):super(LinearModel, self).__init__()# 使用一个线性层,输入为1维,输出为1维self.layers = nn.Linear(1, 1)def forward(self, x):# 直接返回线性层的输出return self.layers(x)model=LinearModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
epoches =500
for n in range(1,epoches+1):epoch_loss=0#以前都是所有数据一块训练,现在是按照批次进行训练for batch_x,batch_y in dataloader:#现在x_train 相当于10个样本,但是现在维度,添加一个维度#10x1   变成样本 x 维度形式y_prd=model(batch_x.unsqueeze(1))#计算损失#y_prd在前面,y_true 是后面batch_loss=criterion(y_prd.squeeze(1),batch_y)#梯度更新#清空之前存储在优化器中的梯度optimizer.zero_grad()#损失函数对模型参数的梯度batch_loss.backward()#根据优化算法更新参数optimizer.step()#计算一下epoch的损失epoch_loss=epoch_loss+batch_loss# 5、显示频率设置#计算一下epoch的平均损失avg_loss=epoch_loss/(len(dataloader))# 不先画图if n % 10 == 0 or n == 1:print(f"epoches:{n},loss:{avg_loss}")torch.save(model.state_dict(),'model.pth')model.load_state_dict(torch.load("model.pth"))
#评估模型
# 评估模型一定要加下面这句话
model.eval()
# 定义数据
x_test=torch.tensor([[1.8]],dtype=torch.float32)
#添加上下文不需要计算梯度
with torch.no_grad():y_pred=model(x_test)threshold = 50  # 设定阈值
y_pred_class = int(y_pred.item() > threshold)# 输出预测结果
print(f"预测值 : {y_pred.item():.4f}")
print(f"预测类 : {y_pred_class}")# 假设真实标签也是 1 或 0,我们用一个假的真实标签来计算评估指标(你可以根据实际情况替换)
y_true_class = 1 if y_data[1] > threshold else 0  # 假设我们预测的是第二个样本# 计算精确度、召回率和 F1 分数
accuracy = accuracy_score([y_true_class], [y_pred_class])
precision = precision_score([y_true_class], [y_pred_class])
recall = recall_score([y_true_class], [y_pred_class])
f1 = f1_score([y_true_class], [y_pred_class])# 输出分类评估指标
print(f"precision : {precision:.4f}")
print(f"recall : {recall:.4f}")
print(f"f1 : {f1:.4f}")

结果:


http://www.ppmy.cn/embedded/159807.html

相关文章

二、CSS笔记

(一)css概述 1、定义 CSS是Cascading Style Sheets的简称,中文称为层叠样式表,用来控制网页数据的表现,可以使网页的表现与数据内容分离。 2、要点 怎么找到标签怎么操作标签对象(element) 3、css的四种引入方式 3.1 行内式 在标签的style属性中设定CSS样式。这种方…

小程序设计和开发:要如何明确目标和探索用户需求?

一、明确小程序的目标 确定业务目标 首先,需要明确小程序所服务的业务领域和目标。例如,是一个电商小程序,旨在促进商品销售;还是一个服务预约小程序,方便用户预订各类服务。明确业务目标有助于确定小程序的核心功能和…

(10) 如何获取 linux 系统上的 TCP 、 UDP 套接字的收发缓存的默认大小,以及代码范例

(1) 先介绍下后面的代码里要用到的基础函数: 以及: (2) 接着给出现代版的 读写 socket 参数的系统函数 : 以及: (3) 给出 一言的 范例代码,获取…

HarmonyOS NEXT:保存应用数据

用户首选项使用 用户首选项的特点 数据体积小、访问频率高、有加载速度要求的数据如用户偏好设置、用户字体大小、应用的配置参数。 用户搜选项(Preferences)提供了轻量级配置数据的持久化能力,支持订阅数据变化的通知能力。不支持分布式同…

全面掌握市场信息:xtquant库在证券品种数据获取中的应用

全面掌握市场信息:xtquant库在证券品种数据获取中的应用 开篇点题:技术背景和应用场景 在量化交易领域,快速准确地获取市场基础信息是至关重要的。xtquant库提供了一种便捷的途径来获取各类证券品种的数据,包括股票、指数、基金等…

【LeetCode 刷题】贪心算法(1)-基础

此博客为《代码随想录》二叉树章节的学习笔记,主要内容为贪心算法基础的相关题目解析。 文章目录 455.分发饼干1005.K次取反后最大化的数组和860.柠檬水找零 455.分发饼干 题目链接 class Solution:def findContentChildren(self, g: List[int], s: List[int]) -…

第九篇:NoSQL 数据库与大数据

第九篇:NoSQL 数据库与大数据 目标读者: 本篇文章适合那些希望学习 NoSQL(非关系型数据库)和大数据处理技术的学习者。如果你对传统的关系型数据库(如 MySQL、PostgreSQL)有一定了解,并希望扩…

5分钟掌握React的Redux Toolkit + Redux

Redux Toolkit Redux 教程 1. 引言 本教程介绍如何使用 Redux Toolkit(RTK) 和 TypeScript 搭建 Redux 状态管理系统。 我们将创建一个 计数器(Counter) 和 待办事项(Todo) 模块,并学习 Redu…