深度学习 -- 逻辑回归 PyTorch实现逻辑回归

news/2024/12/28 14:46:31/

前言

线性回归解决的是回归问题,而逻辑回归解决的是分类问题,这两种问题的区别是前者的目标属性是连续的数值类型,而后者的目标属性是离散的标称类型。

可以将逻辑回归视为神经网络的一个神经元,因此学习逻辑回归能帮助理解神经网络的工作原理。

什么是逻辑回归?

逻辑回归是一种广义的线性回归分析模型,是监督学习的一种重要方法,主要用于二分类问题,但也可以用于多分类问题。

逻辑回归的主要思想是,对于一个二分类问题,先根据样本数据计算出每个特征的概率,然后根据这些概率计算出每个样本属于每个类别的概率,最后根据这些概率来预测测试集数据属于每个类别的概率。

逻辑回归介绍

逻辑回归的推导过程与计算方式类似于回归的过程,但实际上主要是用来解决二分类问题。 在逻辑回归中,输入数据集D被分成两个部分:一类是训练集D_train,一类是测试集D_test。在每次训练时,我们使用一部分数据来训练模型,然后使用另一部分数据来评估模型的性能。 在测试时,我们使用所有的数据来评估模型的性能。

在实际使用中,逻辑回归可以使用各种不同的损失函数来最小化训练数据集和测试数据集之间的均方误差。常见的逻辑回归损失函数包括均方误差损失函数、交叉熵损失函数、对数损失函数等。

Sigmoid函数

Sigmoid函数是一个在生物学中常见的S型函数,也称为S型生长曲线。在信息科学中,由于其单增以及反函数单增等性质,Sigmoid函数常被用作神经网络的激活函数,将变量映射到0,1之间。

在神经网络中经常使用Sigmoid函数作为激活函数,因为它能够有效的输出0-1之间的概率。

代价函数

代价函数(Cost Function)是深度学习模型中用于评估模型性能的函数,它是优化算法的目标函数。代价函数通常定义为损失函数(Loss Function)的平方,这样可以简单地通过计算损失函数值来评估模型的性能。

在深度学习中,代价函数通常是指均方误差(MSE)损失函数,因为均方误差是深度学习中最常用的损失函数之一。均方误差损失函数定义为:

J(y_true, y_pred) = 1/N - ∑i=1N(yi - y_i)^2

其中,y_true是真实标签,y_pred是模型预测的标签,N是样本数量,yi是真实标签对应的样本值。

代价函数的作用是评估模型的性能,其中J(y_true, y_pred)表示真实标签和模型预测标签之间的均方误差。优化算法会在代价函数上进行最小化操作,以最小化损失函数值。

除了均方误差损失函数,还有其他类型的损失函数,如交叉熵损失函数、对数损失函数等,它们在不同的场景下可能更有效或更适合。

逻辑回归在PyTorch中的实现

1 从头开始实现一个逻辑回归

  • 首先定义一个逻辑回归模型
import torchdef sigmoid(z):'''s型激活函数'''g = 1 / (1+torch.exp(-z))return gdef model(x,w,b):'''逻辑回归模型'''return sigmoid(x.mv(w)+b)# w是向量,b是标量,而x是矩阵,使用x.mv(w) 可以实现矩阵x与向量w的相乘

注意:这里w是向量,b是标量,而x是矩阵,使用x.mv(w) 可以实现矩阵x与向量w的相乘

  • 然后定义损失函数和损失函数求导
# 定义损失函数
def loss_fn(y_pred,y):'''损失函数'''loss = - y.mul(y_pred.view_as(y)) - (1-y).mul(1-y_pred.view_as(y))return loss.mean()# 损失函数求导
def grad_loss_fn(y_pred,y):'''损失函数求导'''return y_pred.view_as(y)-y
  • 接着定义一个梯度函数
# 定义梯度函数
def grad_fn(x,y,y_pred):'''梯度函数'''grad_w = grad_loss_fn(y_pred,y)*xgrad_b = grad_loss_fn(y_pred,y)return torch.cat((grad_w.mean(dim=0),grad_b.mean().unsqueeze(0)),0)
  • 模型训练函数
# 模型训练函数
def model_training(x,y,n_epochs,learning_rate,params,print_params=True):'''训练'''for epoch in range(1,n_epochs+1):w,b = params[:-1],params[-1]# 前向传播y_pred = model(x,w,b)# 计算损失loss = loss_fn(y_pred,y)# 梯度grad = grad_fn(x,y,y_pred)# 更新参数params -= learning_rate*gradif epoch == 1 or epoch%10 == 1:print('轮次:%d,\t损失:%f'%(epoch,float(loss)))if print_params:print(f'参数:{params.detach().numpy()}')print(f'梯度:{grad.detach().numpy()}')return params
  • 最后定义main函数
if __name__ == '__main__':# 随机生成数据x = torch.randn(2,2)y = torch.tensor([[1.,0.],[0.,1.]])# 模型参数初始化w = torch.zeros(2)  # tensor([0., 0.])b = torch.zeros(1)  # tensor([0.])params = model_training(x=x,y=y,n_epochs=500,learning_rate=0.1,params=torch.tensor([0.0,0.0,0.0]))print(params.numpy())

http://www.ppmy.cn/news/69502.html

相关文章

Java 中的访问修饰符有什么区别?

Java 中的访问修饰符用于控制类、类的成员变量和方法的访问权限,主要有以下四种: public:公共访问修饰符,可以被任何类访问。public 修饰的类、成员变量和方法可以在任何地方被访问到。 protected:受保护的访问修饰符…

【C++】8.编译:CMake工具入门

😏*★,*:.☆( ̄▽ ̄)/$:*.★* 😏这篇文章主要介绍CMake工具的入门使用。————————————————学其所用,用其所学。——梁启超————————————————— 欢迎来到我的博客,一起学习知识…

【Hello Algorithm】归并排序及其面试题

作者:小萌新 专栏:算法 作者简介:大二学生 希望能和大家一起进步 本篇博客简介:介绍归并排序和几道面试题 归并排序及其面试题 归并排序归并排序是什么归并排序的实际运用归并排序的迭代写法归并排序的时间复杂度 归并排序算法题小…

2023-04-30:用go语言重写ffmpeg的resampling_audio.c示例,它实现了音频重采样的功能。

2023-04-30:用go语言重写ffmpeg的resampling_audio.c示例,它实现了音频重采样的功能。 答案2023-04-30: resampling_audio.c 是 FFmpeg 中的一个源文件,其主要功能是实现音频重采样。 音频重采样是指将一段音频数据从一个采样率…

mysql事务及搜索引擎

mysql事务后半部分 加快查询速度索引会自动排序,(升序) select * from t1;全盘扫描 where可以索引查找show create table 索引是一个排序的列表,包含字段值和相应行数据的物理地址 事务是一种机制,一个…

Spark大数据处理讲课笔记3.5 RDD持久化机制

文章目录 零、本讲学习目标一、RDD持久化(一)引入持久化的必要性(二)案例演示持久化操作1、RDD的依赖关系图2、不采用持久化操作3、采用持久化操作 二、存储级别(一)持久化方法的参数(二&#x…

android log的使用

现在在分析一个android netd的问题,只要一开启热点, for (String ifname : added) {try {Log.d(TAG, "TetheredState, processMessage CMD_TETHER_CONNECTION_CHANGED, add mIfaceName " mIfaceName " ifname " ifname );mNetd.…

etcd的Watch原理

在 Kubernetes 中,各种各样的控制器实现了 Deployment、StatefulSet、Job 等功能强大的 Workload。控制器的核心思想是监听、比较资源实际状态与期望状态是否一致,若不一致则进行协调工作,使其最终一致。 那么当你修改一个 Deployment 的镜像…