【机器学习】—逻辑回归

embedded/2024/11/29 19:20:55/

逻辑回归实现详解

介绍

逻辑回归(Logistic Regression)是一种广泛应用于分类问题的统计模型,尤其适用于二分类问题。本文将通过一个简单的例子,使用Python和PyTorch库实现逻辑回归,并通过可视化展示模型的训练过程和最终结果。

环境准备

在开始之前,确保已经安装了以下库:

  • numpy
  • matplotlib
  • torch
  • sklearn
    可以使用以下命令安装这些库:
pip install numpy matplotlib torch scikit-learn

代码实现

数据生成

我们使用sklearn.datasets.make_blobs函数生成二分类数据集。该函数可以生成具有指定中心和标准差的高斯分布数据点。

from sklearn.datasets import make_blobs
import matplotlib.pyplot as plt
import torch
import numpy as np
# 使用make_blobs随机生成n个样本
x, y = make_blobs(n_samples=200, centers=2, random_state=0, cluster_std=0.5)
x1 = x[:, 0]
x2 = x[:, 1]
# 可视化数据
plt.scatter(x1[y == 1], x2[y == 1], color='blue', marker='o')
plt.scatter(x1[y == 0], x2[y == 0], color='red', marker='x')
plt.xlabel('Feature 1')
plt.ylabel('Feature 2')
plt.title('Generated Data')
plt.show()

模型定义

逻辑回归模型的基本形式为:
h ( θ ) = σ ( θ 0 + θ 1 x 1 + θ 2 x 2 ) h(\theta) = \sigma(\theta_0 + \theta_1 x_1 + \theta_2 x_2) h(θ)=σ(θ0+θ1x1+θ2x2)
其中, σ ( z ) = 1 1 + e − z \sigma(z) = \frac{1}{1 + e^{-z}} σ(z)=1+ez1 是sigmoid函数。

def hypothesis(theta0, theta1, theta2, x1, x2):z = theta0 + theta1 * x1 + theta2 * x2h = torch.sigmoid(z)return h.view(-1, 1)

损失函数

逻辑回归的损失函数通常使用对数损失函数(Log Loss):
J ( h , y ) = − 1 m ∑ i = 1 m [ y ( i ) log ⁡ ( h ( i ) ) + ( 1 − y ( i ) ) log ⁡ ( 1 − h ( i ) ) ] J(h, y) = -\frac{1}{m} \sum_{i=1}^{m} \left[ y^{(i)} \log(h^{(i)}) + (1 - y^{(i)}) \log(1 - h^{(i)}) \right] J(h,y)=m1i=1m[y(i)log(h(i))+(1y(i))log(1h(i))]

def J(h, y):return -torch.mean(y * torch.log(h) + (1 - y) * torch.log(1 - h))

模型训练

我们使用PyTorch的Adam优化器来训练模型。训练过程中,我们不断更新模型参数以最小化损失函数。

if __name__ == '__main__':# 数据准备x1 = torch.tensor(x1, dtype=torch.float32)x2 = torch.tensor(x2, dtype=torch.float32)y = torch.tensor(y, dtype=torch.float32).view(-1, 1)# 初始化参数theta0 = torch.tensor(0.0, requires_grad=True)theta1 = torch.tensor(0.0, requires_grad=True)theta2 = torch.tensor(0.0, requires_grad=True)# 优化器optimizer = torch.optim.Adam([theta0, theta1, theta2])# 训练模型for epoch in range(10000):h = hypothesis(theta0, theta1, theta2, x1, x2)loss = J(h, y)loss.backward()optimizer.step()optimizer.zero_grad()if epoch % 1000 == 0:print(f'After {epoch} epochs, the loss is {loss.item():.3f}')# 获取训练后的参数w1 = theta1.item()w2 = theta2.item()b = theta0.item()# 可视化决策边界x = np.linspace(-1, 6, 100)d = -(w1 * x + b) * 1.0 / w2plt.scatter(x1[y == 1], x2[y == 1], color='blue', marker='o')plt.scatter(x1[y == 0], x2[y == 0], color='red', marker='x')plt.plot(x, d, color='green')plt.xlabel('Feature 1')plt.ylabel('Feature 2')plt.title('Decision Boundary')plt.show()

结果分析

输出:

after 0 ,the loss is 0.693
after 1000 ,the loss is 0.188
after 2000 ,the loss is 0.086
after 3000 ,the loss is 0.049
after 4000 ,the loss is 0.031
after 5000 ,the loss is 0.020
after 6000 ,the loss is 0.014
after 7000 ,the loss is 0.009
after 8000 ,the loss is 0.007
after 9000 ,the loss is 0.005

在这里插入图片描述

通过上述代码,我们可以生成二分类数据集并训练逻辑回归模型。训练过程中,损失函数逐渐减小,最终模型能够较好地拟合数据。最终的决策边界将数据集中的两个类别分开。

总结

本文通过一个简单的例子,详细介绍了如何使用Python和PyTorch实现逻辑回归模型。通过生成数据、定义模型、定义损失函数、训练模型和可视化结果,我们展示了逻辑回归的基本流程。希望本文对读者理解逻辑回归有所帮助。

参考资料

  • Scikit-learn官方文档
  • PyTorch官方文档
  • 逻辑回归的数学原理

http://www.ppmy.cn/embedded/141554.html

相关文章

Sofia-SIP 使用教程

Sofia-SIP 是一个开源的 SIP 协议栈,广泛用于 VoIP 和即时通讯应用。以下是一些基本的使用教程,帮助你快速上手 Sofia-SIP。 1. 安装 Sofia-SIP 首先,你需要安装 Sofia-SIP 库。你可以从其官方 GitHub 仓库克隆源代码并编译安装&#xff1a…

CodeIgniter URL结构

CodeIgniter 的URL 结构设计得简洁且易于管理。通常遵循以下模式&#xff1a; http://<domain>/<index_page>/<controller>/<method>/<parameters> 下面是每个部分的详细说明&#xff1a; <domain>&#xff1a; 这是你的网站域名&#…

HTML CSS JS基础考试题与答案

一、选择题&#xff08;2分/题&#xff09; 1&#xff0e;下面标签中&#xff0c;用来显示段落的标签是&#xff08; d &#xff09;。 A、<h1> B、<br /> C、<img /> D、<p> 2. 网页中的图片文件位于html文件的下一级文件夹img中&#xff0c;…

Python中的map函数

Python中的map函数是一种常用的高雅实现&#xff0c;它能够在不使用第三方库的情况下对一个列表进行映射&#xff0c;并返回一个新的列表。map函数不仅能够提高Python代码的可读性&#xff0c;还能够拓展Python的功能&#xff0c;使其成为一种强大的数据处理工具。 Python中的…

Vue3+node.js实现注册

文章目录 前端代码实现后端代码实现 效果图 前端代码实现 <template><div class"register-container"><el-card class"register-card"><template #header><div class"card-header"><span>注册</span&…

林业产品推荐系统:Spring Boot开发手册

3 系统分析 这部分内容虽然在开发流程中处于最开始的环节&#xff0c;但是它对接下来的设计和实现起着重要的作用&#xff0c;因为系统分析结果的好坏&#xff0c;将直接影响后面环节的开展。 3.1可行性研究 影响系统开发的因素有很多&#xff0c;比如开发成本高就不适合开展&a…

深入浅出:JVM 的架构与运行机制

一、什么是JVM 1、什么是JDK、JRE、JVM JDK是 Java语言的软件开发工具包&#xff0c;也是整个java开发的核心&#xff0c;它包含了JRE和开发工具包JRE&#xff0c;Java运行环境&#xff0c;包含了JVM和Java的核心类库&#xff08;Java API&#xff09;JVM&#xff0c;Java虚拟…

输入json 达到预览效果

下载 npm i vue-json-pretty2.4.0 <template><div class"newBranchesDialog"><t-base-dialogv-if"addDialogShow"title"Json数据配置"closeDialog"closeDialog":dialogVisible"addDialogShow":center"…