损失函数:Cross Entropy Loss (交叉熵损失函数)

news/2024/10/22 18:26:58/

损失函数:Cross Entropy Loss (交叉熵损失函数)

  • 前言
  • 相关介绍
  • Softmax函数
    • 代码实例
  • Cross Entropy Loss (交叉熵损失函数)
    • Cross Entropy Loss与BCE loss区别
    • 代码实例

前言

  • 由于本人水平有限,难免出现错漏,敬请批评改正。
  • 更多精彩内容,可点击进入人工智能知识点专栏、Python日常小操作专栏、OpenCV-Python小应用专栏、YOLO系列专栏、自然语言处理专栏或我的个人主页查看
  • 基于DETR的人脸伪装检测
  • YOLOv7训练自己的数据集(口罩检测)
  • YOLOv8训练自己的数据集(足球检测)
  • YOLOv5:TensorRT加速YOLOv5模型推理
  • YOLOv5:IoU、GIoU、DIoU、CIoU、EIoU
  • 玩转Jetson Nano(五):TensorRT加速YOLOv5目标检测
  • YOLOv5:添加SE、CBAM、CoordAtt、ECA注意力机制
  • YOLOv5:yolov5s.yaml配置文件解读、增加小目标检测层
  • Python将COCO格式实例分割数据集转换为YOLO格式实例分割数据集
  • YOLOv5:使用7.0版本训练自己的实例分割模型(车辆、行人、路标、车道线等实例分割)
  • 使用Kaggle GPU资源免费体验Stable Diffusion开源项目

相关介绍

损失函数(Loss Function)在机器学习和深度学习中扮演着至关重要的角色,它是一个评估模型预测输出与真实标签之间差异程度的函数。损失函数量化了模型预测错误的程度,并在训练过程中作为优化的目标,模型通过不断地调整内部参数以最小化损失函数的值,从而实现更好的拟合数据和泛化能力。

主要特性与作用:

  1. 量化误差:损失函数将模型预测值与实际目标值之间的差异转化为数值,这样就可以通过数值大小直观地衡量模型的预测效果。

  2. 优化导向:在训练神经网络时,优化算法(如梯度下降法)会根据损失函数的梯度来更新模型参数,使损失函数朝着最小化方向移动。

  3. 种类多样:根据不同的任务和需求,有多种不同的损失函数可供选择。例如,在二分类任务中常用的有二元交叉熵损失函数(Binary Cross-Entropy Loss/BCE Loss),在多分类任务中有softmax交叉熵损失函数,在回归任务中常见的是均方误差(Mean Squared Error/MSE)和绝对误差(Mean Absolute Error/MAE)等。

常见的损失函数包括:

  • 二元交叉熵损失(Binary Cross-Entropy Loss / BCE Loss):适用于二分类问题,衡量的是sigmoid函数输出的概率与真实标签间的距离。

  • 多分类交叉熵损失(Categorical Cross-Entropy Loss):对于多分类问题,每个样本可能属于多个类别之一,使用softmax函数和交叉熵损失。

  • 均方误差(Mean Squared Error / MSE):在回归问题中常用,计算预测值与真实值之差的平方平均。

  • 均方根误差(Root Mean Squared Error / RMSE):MSE的平方根,也是回归任务中的损失函数。

  • Huber损失:一种既能兼顾均方误差又能容忍较大误差的混合损失函数,常用于回归问题中。

  • Dice系数损失(Dice Loss):在图像分割任务中广泛使用,衡量的是预测分割区域与真实分割区域的重叠程度。

  • IoU(Intersection over Union)损失:也是在图像分割领域常用的损失函数,计算的是预测区域与真实区域交集与其并集的比例。

  • Focal Loss:在目标检测中应对类别不平衡问题的损失函数,对易分类的样本给予较小的权重,强调难分类样本的训练。

每种损失函数都有其适用的情境和优缺点,选择合适的损失函数是优化模型性能的关键因素之一。

交叉熵(Cross-Entropy)之所以能够用于分类问题,是因为它能够很好地衡量模型预测的概率分布与实际标签分布之间的相似度,而且它拥有几个非常适合分类任务的重要特性:

  1. 信息论基础:交叉熵源于信息论中的概念,表示一个概率分布 (p) 与另一个概率分布 (q) 的差异。在分类问题中,我们可以把 (p) 视为真实数据的标签分布,(q)视为模型预测的概率分布。交叉熵可以衡量模型预测概率与实际类别标签之间的信息差异。

  2. 最大似然估计的自然延伸:在机器学习中,我们通常倾向于最大化模型对数据的似然性,即模型预测给定数据标签的概率。交叉熵损失函数实际上是负对数似然函数在多项式分布(对于多分类问题)或伯努利分布(对于二分类问题)下的特殊情况,通过最小化交叉熵损失,相当于最大化数据的对数似然性。

  3. 梯度稳定性:交叉熵损失函数是连续且可微的,其梯度容易计算且对于大多数情况是有意义的。这意味着在训练过程中,模型可以根据损失函数的梯度进行有效的参数更新。

  4. 稀疏性惩罚:对于多分类问题,softmax函数与交叉熵损失组合使用时,不仅鼓励模型正确预测每个样本的类别,同时也通过归一化机制惩罚了预测概率分布的不均匀性,即模型不能过于肯定任何一个错误类别。

  5. 处理多类别和二类别问题:交叉熵既可以用于处理二分类问题(通过二元交叉熵,Binary Cross-Entropy),也可以处理多分类问题(通过多类别交叉熵,Multiclass
    Cross-Entropy)。在二分类问题中,通常搭配Sigmoid函数输出概率;在多分类问题中,通常配合Softmax函数生成类别概率分布。

总的来说,交叉熵损失函数因其良好的理论基础、优化目标清晰以及在实践中的优秀表现,成为了分类问题中最常用的损失函数之一。

Softmax函数

Softmax函数是深度学习和机器学习中广泛使用的激活函数,特别是在多分类问题中。它的目的是将一个线性变换的输出(通常称为logits)映射为一个概率分布,使得所有类别的概率总和为1,每个类别的概率都在0到1之间。

Softmax函数的形式:

对于一个向量 ( z ) ,其中包含每个类别的原始得分(logits),Softmax函数的计算公式如下:

s o f t m a x ( z ) i = e z i ∑ j = 1 K e z j softmax(z)_i = \frac{e^{z_i}}{\sum_{j=1}^{K} e^{z_j}} softmax(z)i=j=1Kezjezi

其中:

  • ( K ) 表示类别总数。
  • ( z_i ) 表示第 ( i ) 个类别的得分。
  • ( softmax(z)_i ) 表示第 ( i ) 个类别的归一化概率。

整个Softmax函数的结果是一个概率分布向量,其中每个元素都是原得分经过指数函数变换后再除以所有得分指数函数值之和,因此所有元素的和为1。

Softmax函数的特性:

  1. 概率性质:Softmax函数确保输出的每个元素都是非负数,并且所有元素的和为1,满足概率分布的要求。
  2. 竞争性:Softmax函数会使得分最高的类别获得最大的概率值,其余类别的概率按比例递减,形成了一种“赢家通吃”的效应。
  3. 平滑连续:由于使用了指数函数和平滑的除法运算,Softmax函数输出是平滑且连续的,便于在训练过程中梯度的计算和传播。

应用场景

在深度学习的多分类问题中,例如图像分类、文本分类等任务,Softmax函数通常与交叉熵损失函数一起使用。模型最后一层通常会产生一个logits向量,接着通过Softmax函数得到每个类别的概率,最后计算与实际标签之间的交叉熵损失,以此指导模型参数的更新。

代码实例

在PyTorch中,你可以直接使用torch.softmax()函数来实现Softmax操作。下面是一个简单的实例:

import torch# 假设我们有一个代表logits的张量
logits = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])# 使用torch.softmax函数计算Softmax值
probs = torch.softmax(logits, dim=1)print(probs)
'''
tensor([[0.0900, 0.2447, 0.6652],[0.0900, 0.2447, 0.6652]])
'''

上述代码中,logits是一个2x3的张量,代表两个样本的三个类别的原始得分。dim=1表示我们在每个样本的类别间计算Softmax,也就是对每一行进行操作。执行torch.softmax()后,probs张量将包含每个样本各类别的归一化概率。

注意,如果你正在训练一个多分类模型并且使用了nn.CrossEntropyLoss()损失函数,通常不需要单独调用torch.softmax(),因为该损失函数内部已经包含了对logits计算Softmax的过程。在多数情况下,你只需将模型的原始输出(logits)传递给损失函数,并配合真实类别标签即可。

Cross Entropy Loss (交叉熵损失函数)

在这里插入图片描述

nn.CrossEntropyLoss是PyTorch中用于多分类问题的一种损失函数,特别适用于输出层是softmax激活函数后的分类任务。它结合了softmax函数和交叉熵损失(Cross-Entropy Loss)的操作,简化了模型训练过程中的计算步骤和代码实现。

基本概念:

  • 交叉熵损失(Cross-Entropy Loss)源于信息论中的熵概念,用于衡量两个概率分布之间的差异。在机器学习和深度学习中,它用来量化模型预测的概率分布与真实标签分布之间的差距。

  • softmax函数:在多分类问题中,softmax函数将模型的线性输出(logits)转换为一个概率分布,确保所有类别的概率和为1。softmax函数的输出可以用作模型预测的概率分布。

nn.CrossEntropyLoss的工作方式:

  • PyTorch中的nn.CrossEntropyLoss接收两个输入:

    • input:模型的原始输出(logits),通常是未经过softmax激活的张量。
    • target:真实的一维标签张量,包含了每个样本所属类别的索引,通常采用LongTensor类型。
  • 内部处理流程

    • 对于每个样本,首先计算其对应的softmax概率分布。
    • 然后,根据真实标签计算交叉熵损失。损失是对每个样本的损失值进行平均得到的,如果没有特殊指定,损失默认会在批次(batch)层面求平均。
  • 损失函数计算公式

    • 对于单个样本,交叉熵损失是 -∑(yi * log(pi)),其中 yi 是实际标签的one-hot编码(在实际情况中,由于标签是索引形式,nn.CrossEntropyLoss内部会处理one-hot编码),pi 是模型预测的该类别概率。
    • 对于整个批次,损失则是各样本损失的平均。

Cross Entropy Loss与BCE loss区别

  • 关于BCE Loss(二元交叉熵损失函数)的相关知识,可查阅损失函数:BCE Loss(二元交叉熵损失函数)、Dice Loss(Dice相似系数损失函数)

CrossEntropyLossBCELoss 都是 PyTorch 中用于监督学习分类任务的损失函数,它们分别适用于不同的分类场景:

BCELoss (Binary Cross Entropy Loss)

  • BCELoss 是二元交叉熵损失函数,专门用于二分类问题,即输出只有两类(0或1,正面或负面,真或假等)。
  • 使用 BCELoss 时,模型的输出一般是通过 Sigmoid 函数得到的概率值,介于0和1之间。
  • 计算公式为 -y * log(p) - (1-y) * log(1-p),其中 y 是真实的标签(0或1),p 是模型预测的概率。
  • 输入要求是经过Sigmoid激活函数之后的输出张量和相应的真实标签张量,二者形状必须相同。

CrossEntropyLoss (Multinomial Cross Entropy Loss 或者 Softmax Cross Entropy Loss)

  • CrossEntropyLoss 适用于多分类问题,它可以处理任何数量的类别,不仅仅是二分类。
  • 对于多分类问题,模型的输出通常是一个 logits(未归一化的预测值),然后CrossEntropyLoss内部会先通过Softmax函数将其转换为概率分布,然后再计算交叉熵。
  • 使用 CrossEntropyLoss 时,不需要手动在输出层之前添加Sigmoid或Softmax函数,因为它已经包含了Softmax运算步骤。
  • 它结合了Softmax函数和交叉熵损失的功能,简化了多分类任务的训练流程,其计算公式基于交叉熵和类别间的互斥性(即对于每个样本,所有类别的概率之和为1)。
  • 输入要求是未经Softmax激活函数处理的logits张量和one-hot编码形式的真实标签张量。

总结来说,两者的主要区别在于:

  • BCELoss用于二分类任务,而CrossEntropyLoss适用于多分类任务。
  • BCELoss前接Sigmoid,CrossEntropyLoss前接Softmax(但这一步在使用CrossEntropyLoss时由损失函数内部自动完成)。
  • BCELoss处理的是二元概率分布,而CrossEntropyLoss处理的是多类别概率分布。

代码实例

import torch
import torch.nn as nn# 假设模型输出和真实标签
output_logits = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])  # 假设输出是两样本的三个类别的logits
targets = torch.tensor([1, 2])  # 假设第一样本是第二类,第二样本是第三类# 创建交叉熵损失函数
criterion = nn.CrossEntropyLoss()# 计算损失
loss = criterion(output_logits, targets)print(loss.item())  # 输出损失值 # 0.9076058864593506

在上述代码中,nn.CrossEntropyLoss()函数内部处理了softmax激活和交叉熵损失计算,直接返回了模型预测与真实标签之间的交叉熵损失。

  • 由于本人水平有限,难免出现错漏,敬请批评改正。
  • 更多精彩内容,可点击进入人工智能知识点专栏、Python日常小操作专栏、OpenCV-Python小应用专栏、YOLO系列专栏、自然语言处理专栏或我的个人主页查看
  • 基于DETR的人脸伪装检测
  • YOLOv7训练自己的数据集(口罩检测)
  • YOLOv8训练自己的数据集(足球检测)
  • YOLOv5:TensorRT加速YOLOv5模型推理
  • YOLOv5:IoU、GIoU、DIoU、CIoU、EIoU
  • 玩转Jetson Nano(五):TensorRT加速YOLOv5目标检测
  • YOLOv5:添加SE、CBAM、CoordAtt、ECA注意力机制
  • YOLOv5:yolov5s.yaml配置文件解读、增加小目标检测层
  • Python将COCO格式实例分割数据集转换为YOLO格式实例分割数据集
  • YOLOv5:使用7.0版本训练自己的实例分割模型(车辆、行人、路标、车道线等实例分割)
  • 使用Kaggle GPU资源免费体验Stable Diffusion开源项目

http://www.ppmy.cn/news/1423236.html

相关文章

FreeRTOS_day3

1.总结任务调度算法之间的区别,重新实现一遍任务调度算法的代码。 抢占式调度:高优先级的任务可以打断低优先级的任务执行 时间片轮转:相同优先级的任务有相同的时间片(1ms),时间片耗尽任务会强制退出 协…

OpenHarmony轻量系统开发【12】OneNET云接入

12.1 OneNET云介绍 通常来说,一个物联网产品应当包括设备、云平台、手机APP。我将在鸿蒙系统上移植MQTT协议、OneNET接入协议,实现手机APP、网页两者都可以远程(跨网络,不是局域网的)访问开发板数据,并控制…

集群开发学习(一)(安装GO和MySQL,K8S基础概念)

完成gin小任务 参考文档: https://www.kancloud.cn/jiajunxi/ginweb100/1801414 https://github.com/hanjialeOK/going 最终代码地址:https://github.com/qinliangql/gin_mini_test.git 学习 1.安装go wget https://dl.google.com/go/go1.20.2.linu…

攻防世界---reverse---1000Click

1.方法一:直接点击1000(可能是我太闲了,真的点了1000次) 2.方法二 1.将exe文件放在IDA中分析 使用shiftfnf12,查看字符串,使用ctrlf搜索字符串 2.任意打开一个flag,里面有许多flag&#xff0c…

Stability AI 发布 SD3 API:开启人工智能新篇章

文章目录 1.Stable Diffusion 3 API开放了! 2.Stability AI Document地址3.获取API Key4.API方式调用SD3出图接口地址接口请求规范接口请求响应结果 5.Stable Diffusion 3.0、Stable Image Core、Fooocus 2.3.1、MidJounery效果查看 1.Stable Diffusion 3 API开放了! Stabilit…

支付宝支付之SpringBoot整合支付宝创建自定义支付二维码

文章目录 自定义支付二维码pom.xmlapplication.yml自定义二维码类AlipayService.javaAlipayServiceImpl.javaAlipayController.javaqrCode.html 自定义支付二维码 继&#xff1a;SpringBoot支付入门 pom.xml <dependency><groupId>org.springframework.boot<…

攻防演练,作为蓝方,centos的操作系统,怎么查看是不是有隐藏用户,有没有获取权限

在攻防演练中&#xff0c;作为蓝方&#xff08;防守方&#xff09;&#xff0c;检查 CentOS 操作系统中是否存在隐藏的用户以及是否有未授权的权限获取是非常重要的。以下是一些检查步骤和工具的使用方法&#xff0c;这些可以帮助你确保系统的安全和完整性。 1. 检查系统用户 …

案例与脚本实践:DolphinDB 轻量级实时数仓的构建与应用

DolphinDB 高性能分布式时序数据库&#xff0c;具有分布式计算、事务支持、多模存储、以及流批一体等能力&#xff0c;非常适合作为一款理想的轻量级大数据平台&#xff0c;轻松搭建一站式的高性能实时数据仓库。 本教程将以案例与脚本的方式&#xff0c;介绍如何通过 Dolphin…