机器学习之交叉熵

ops/2024/12/15 10:27:04/

交叉熵(Cross-Entropy)是机器学习中用于衡量预测分布与真实分布之间差异的一种损失函数,特别是在分类任务中非常常见。它源于信息论,反映了两个概率分布之间的距离。


交叉熵的数学定义

对于分类任务,假设我们有:

  • 一个真实的分布 y,用独热编码表示,例如 y=[0,1,0] 表示属于第二类。
  • 一个预测的概率分布\hat{y},例如 \hat{y} = [0.1, 0.7, 0.2],表示模型预测属于各类的概率。

交叉熵的公式为:

其中:

  • yi是真实分布中第 i 类的值(独热编码下只有一个为 1,其余为 0)。
  • \hat{y}_i 是模型预测的第 i 类的概率。

由于 y 是独热编码,交叉熵可以简化为:

其中 c 是真实类别的索引。


交叉熵的直观理解

  1. 信息论解释

    • 交叉熵可以理解为用预测分布\hat{y} 去编码真实分布 y 的代价。
    • 如果预测越接近真实分布(即预测概率\hat{y}_c 越接近 1),交叉熵越小,模型表现越好。
  2. 惩罚机制

    • 如果模型的预测概率 \hat{y}_c 很低(接近 0),交叉熵会给出很大的惩罚。
    • 这促使模型更自信地预测正确类别。

交叉熵的应用场景

  1. 二分类问题: 对于二分类任务,真实标签 y∈{0,1},模型预测 \hat{y} \in [0, 1]。交叉熵损失为:

  2. 多分类问题: 对于 K 类分类任务,交叉熵损失为:

    其中 y_k 表示第 k 类的真实标签,\hat{y}_k 表示模型对第 k 类的预测概率。

  3. 目标检测和语义分割: 交叉熵通常与其他损失(如 IoU、Dice Loss)结合使用,以处理多任务学习。


交叉熵的优点

  1. 数学性质优良:损失函数连续且可微,适合梯度下降优化。
  2. 自然适用于概率分布:直接用概率度量模型的预测质量。
  3. 对错误预测的敏感性:能有效惩罚错误分类,提高模型对分类任务的优化效果。

交叉熵的缺点

  1. 对预测不平衡的敏感性

    • 如果某些类别的样本数很少,模型可能忽视这些类别。
    • 解决方法:可以结合加权交叉熵(Weighted Cross-Entropy)。
  2. 对异常值的敏感性:当预测概率非常接近 0 时,交叉熵的惩罚会非常大,可能导致数值不稳定。


交叉熵与其它损失的关系

  1. 与均方误差(MSE)

    • MSE 更适合回归任务,而交叉熵适合分类任务。
    • 对于分类任务,MSE 可能导致梯度消失,影响优化效果。
  2. 与 KL 散度:交叉熵是 KL 散度的一部分,衡量预测分布与真实分布的差异。


实现示例

二分类问题的交叉熵损失(Python + PyTorch)
import torch
import torch.nn as nn# 假设真实标签和预测概率
y_true = torch.tensor([1, 0, 1], dtype=torch.float32)  # 真实标签
y_pred = torch.tensor([0.8, 0.2, 0.6], dtype=torch.float32)  # 预测概率# 定义二分类交叉熵损失
loss_fn = nn.BCELoss()
loss = loss_fn(y_pred, y_true)
print(f"Binary Cross-Entropy Loss: {loss.item():.4f}")
多分类问题的交叉熵损失
# 假设真实标签和预测概率
y_true = torch.tensor([1, 0, 2])  # 真实标签(类别索引)
y_pred = torch.tensor([[0.3, 0.6, 0.1],[0.1, 0.2, 0.7],[0.8, 0.1, 0.1]])  # 预测概率# 定义多分类交叉熵损失
loss_fn = nn.CrossEntropyLoss()
loss = loss_fn(y_pred, y_true)
print(f"Multi-class Cross-Entropy Loss: {loss.item():.4f}")

交叉熵是分类任务中的核心损失函数之一,其优异的性质和强大的优化能力使其在机器学习的各个领域得到了广泛应用。


http://www.ppmy.cn/ops/142066.html

相关文章

Scala测试

implicit class StrongString(str: String) {def isPhone: Boolean {val reg "1[3-9]\\d{9}".rreg.matches(str)}}def main(args: Array[String]): Unit {val str: String "18888488488"// 需求:给字符串补充一个功能isPhone,判…

【Python系列】异步 Web 服务器

???欢迎来到我的博客,很高兴能够在这里和您见面!希望您在这里可以感受到一份轻松愉快的氛围,不仅可以获得有趣的内容和知识,也可以畅所欲言、分享您的想法和见解。 推荐:kwan 的首页,持续学习,不断总结,共同进步,活到老学到老…

es的join是什么数据类型

在 Elasticsearch 中,parent 并不是一个独立的数据类型,而是与 join 数据类型一起使用的一个概念。join 数据类型用于在同一个索引中建立父子文档之间的关系,允许你在一个索引内表示层级结构或关联关系。通过 join 字段,你可以定义不同类型的文档(如父文档和子文档),并指…

【前端面试】随机、结构赋值、博弈题

解构赋值(Destructuring Assignment)是 JavaScript ES6 引入的一项非常有用的特性,它允许我们快速地从数组或对象中提取值,并将它们赋给变量。这种方式使得代码更加简洁、易读,并且能够减少重复的访问和赋值操作。 1.…

基于Spring Boot + Vue的摄影师分享交流社区的设计与实现

博主介绍:java高级开发,从事互联网行业六年,熟悉各种主流语言,精通java、python、php、爬虫、web开发,已经做了多年的设计程序开发,开发过上千套设计程序,没有什么华丽的语言,只有实…

Android-ImagesPickers 拍照崩溃优化

Android-ImagesPickers 作为老牌图片选择器,帮助了很多牛马宝宝,刚好最近用到了多相册选择以及拍照,可能是高版本机型问题,导致拍照后就闪退 原作者文章以及git Android实用视图动画及工具系列之九:漂亮的图片选择器…

配置mysqld(读取选项内容,基本配置),数据目录(配置的必要性,目录下的内容,具体文件介绍,修改配置)

目录 配置mysqld 读取选项内容 介绍 启动脚本 基本配置 内容 端口号 数据目录的路径 配置的必要性 配置路径 mysql数据目录 具体文件 修改配置时 权限问题 配置mysqld 读取选项内容 介绍 会从[mysqld] / [server] 节点中读取选项内容 优先读取[server] 虽然服务…

使用html2canvas实现前端截图

一、主要功能 网页截图:html2canvas通过读取DOM结构和元素的CSS样式,在客户端生成图像,不依赖于服务端的渲染。它可以将指定的DOM元素渲染为画布(canvas),并生成图像。多种输出格式:生成的图像…