【机器学习】—逻辑回归

news/2024/11/29 7:35:02/

逻辑回归实现详解

介绍

逻辑回归(Logistic Regression)是一种广泛应用于分类问题的统计模型,尤其适用于二分类问题。本文将通过一个简单的例子,使用Python和PyTorch库实现逻辑回归,并通过可视化展示模型的训练过程和最终结果。

环境准备

在开始之前,确保已经安装了以下库:

  • numpy
  • matplotlib
  • torch
  • sklearn
    可以使用以下命令安装这些库:
pip install numpy matplotlib torch scikit-learn

代码实现

数据生成

我们使用sklearn.datasets.make_blobs函数生成二分类数据集。该函数可以生成具有指定中心和标准差的高斯分布数据点。

from sklearn.datasets import make_blobs
import matplotlib.pyplot as plt
import torch
import numpy as np
# 使用make_blobs随机生成n个样本
x, y = make_blobs(n_samples=200, centers=2, random_state=0, cluster_std=0.5)
x1 = x[:, 0]
x2 = x[:, 1]
# 可视化数据
plt.scatter(x1[y == 1], x2[y == 1], color='blue', marker='o')
plt.scatter(x1[y == 0], x2[y == 0], color='red', marker='x')
plt.xlabel('Feature 1')
plt.ylabel('Feature 2')
plt.title('Generated Data')
plt.show()

模型定义

逻辑回归模型的基本形式为:
h ( θ ) = σ ( θ 0 + θ 1 x 1 + θ 2 x 2 ) h(\theta) = \sigma(\theta_0 + \theta_1 x_1 + \theta_2 x_2) h(θ)=σ(θ0+θ1x1+θ2x2)
其中, σ ( z ) = 1 1 + e − z \sigma(z) = \frac{1}{1 + e^{-z}} σ(z)=1+ez1 是sigmoid函数。

def hypothesis(theta0, theta1, theta2, x1, x2):z = theta0 + theta1 * x1 + theta2 * x2h = torch.sigmoid(z)return h.view(-1, 1)

损失函数

逻辑回归的损失函数通常使用对数损失函数(Log Loss):
J ( h , y ) = − 1 m ∑ i = 1 m [ y ( i ) log ⁡ ( h ( i ) ) + ( 1 − y ( i ) ) log ⁡ ( 1 − h ( i ) ) ] J(h, y) = -\frac{1}{m} \sum_{i=1}^{m} \left[ y^{(i)} \log(h^{(i)}) + (1 - y^{(i)}) \log(1 - h^{(i)}) \right] J(h,y)=m1i=1m[y(i)log(h(i))+(1y(i))log(1h(i))]

def J(h, y):return -torch.mean(y * torch.log(h) + (1 - y) * torch.log(1 - h))

模型训练

我们使用PyTorch的Adam优化器来训练模型。训练过程中,我们不断更新模型参数以最小化损失函数。

if __name__ == '__main__':# 数据准备x1 = torch.tensor(x1, dtype=torch.float32)x2 = torch.tensor(x2, dtype=torch.float32)y = torch.tensor(y, dtype=torch.float32).view(-1, 1)# 初始化参数theta0 = torch.tensor(0.0, requires_grad=True)theta1 = torch.tensor(0.0, requires_grad=True)theta2 = torch.tensor(0.0, requires_grad=True)# 优化器optimizer = torch.optim.Adam([theta0, theta1, theta2])# 训练模型for epoch in range(10000):h = hypothesis(theta0, theta1, theta2, x1, x2)loss = J(h, y)loss.backward()optimizer.step()optimizer.zero_grad()if epoch % 1000 == 0:print(f'After {epoch} epochs, the loss is {loss.item():.3f}')# 获取训练后的参数w1 = theta1.item()w2 = theta2.item()b = theta0.item()# 可视化决策边界x = np.linspace(-1, 6, 100)d = -(w1 * x + b) * 1.0 / w2plt.scatter(x1[y == 1], x2[y == 1], color='blue', marker='o')plt.scatter(x1[y == 0], x2[y == 0], color='red', marker='x')plt.plot(x, d, color='green')plt.xlabel('Feature 1')plt.ylabel('Feature 2')plt.title('Decision Boundary')plt.show()

结果分析

输出:

after 0 ,the loss is 0.693
after 1000 ,the loss is 0.188
after 2000 ,the loss is 0.086
after 3000 ,the loss is 0.049
after 4000 ,the loss is 0.031
after 5000 ,the loss is 0.020
after 6000 ,the loss is 0.014
after 7000 ,the loss is 0.009
after 8000 ,the loss is 0.007
after 9000 ,the loss is 0.005

在这里插入图片描述

通过上述代码,我们可以生成二分类数据集并训练逻辑回归模型。训练过程中,损失函数逐渐减小,最终模型能够较好地拟合数据。最终的决策边界将数据集中的两个类别分开。

总结

本文通过一个简单的例子,详细介绍了如何使用Python和PyTorch实现逻辑回归模型。通过生成数据、定义模型、定义损失函数、训练模型和可视化结果,我们展示了逻辑回归的基本流程。希望本文对读者理解逻辑回归有所帮助。

参考资料

  • Scikit-learn官方文档
  • PyTorch官方文档
  • 逻辑回归的数学原理

http://www.ppmy.cn/news/1550847.html

相关文章

C语言——海龟作图(对之前所有内容复习)

一.问题描述 海龟作图 设想有一只机械海龟,他在C程序控制下在屋里四处爬行。海龟拿了一只笔,这支笔或者朝上,或者朝下。当笔朝下时,海龟用笔画下自己的移动轨迹;当笔朝上时,海龟在移动过程中什么也不画。 …

uniapp介入极光推送教程 超级详细

直接按照下面教程操作 一步一步来 很快就能 完成 下面的文章非常详细 ,我就不班门弄斧了 直接上原文链接 https://blog.csdn.net/weixin_52830464/article/details/143823231

【Linux系列】Chrony时间同步服务器搭建完整指南

1. 简介 Chrony是一个用于Linux系统的高效、精准的时间同步工具,通常用于替代传统的NTP(Network Time Protocol)服务。Chrony不仅在系统启动时提供快速的时间同步,还能在时钟漂移较大的情况下进行及时调整,因此广泛应…

No.2 杀戮尖塔Godot复刻2卡牌拖动和状态机1|CardUI|BattleUI

杀戮尖塔中有两种卡 单一目标卡牌和非单一目标卡牌 使用卡牌方法: 如果按住鼠标左键拖动防御卡并将其释放到屏幕中的某个位置,该卡就会被打出另一种方法是鼠标左键单击防御卡,不按下左键,将其拖到屏幕中间,再次单击鼠…

Django websocket 进行实时通信(消费者)

1. settings.py 增加 ASGI_APPLICATION "django_template_v1.routing.application"CHANNEL_LAYERS {"default": {# This example apps uses the Redis channel layer implementation channels_redis"BACKEND": "channels_redis.core.Red…

代码随想录算法训练营第六十天|Day60 图论

Bellman_ford 队列优化算法(又名SPFA) https://www.programmercarl.com/kamacoder/0094.%E5%9F%8E%E5%B8%82%E9%97%B4%E8%B4%A7%E7%89%A9%E8%BF%90%E8%BE%93I-SPFA.html 本题我们来系统讲解 Bellman_ford 队列优化算法 ,也叫SPFA算法&#xf…

Zookeeper学习心得

本人学zookeeper时按照此文路线学的 Zookeeper学习大纲 - 似懂非懂视为不懂 - 博客园 一、Zookeeper安装 ZooKeeper 入门教程 - Java陈序员 - 博客园 Docker安装Zookeeper教程(超详细)_docker 安装zk-CSDN博客 二、 zookeeper的数据模型 ZooKeepe…

免费下载 | 2024年中国网络安全产业分析报告

《2024年中国网络安全产业分析报告》由中国网络安全产业联盟(CCIA)发布,主要内容包括: 前言:强调网络安全是国家安全的重要组成部分,概述了中国在网络安全治理方面的进展和挑战。 网络安全产业发展形势&am…