交叉熵损失函数原理详解

news/2025/1/22 8:35:59/

交叉熵损失函数原理详解

在学习pytorch的神经网络模型里,经常用到交叉熵损失函数(CrossEntropy Loss),只知道它是分类问题中经常使用的一种损失函数,对于其内部的原理总是模模糊糊,而且一般使用交叉熵作为损失函数时,在模型的输出层总会接一个softmax函数,至于为什么要怎么做也是不懂,所以专门花了一些时间打算从原理入手,搞懂它,故在此写一篇博客进行总结,以便以后翻阅。

交叉熵简介

交叉熵是信息论中的一个重要概念,主要用于度量两个概率分布间的差异性,要理解交叉熵,需要先了解下面几个概念。

信息量

信息奠基人香农(Shannon)认为“信息是用来消除随机不确定性的东西”,也就是说衡量信息量的大小就是看这个信息消除不确定性的程度。

太阳从东边升起”,这条信息并没有减少不确定性,因为太阳肯定是从东边升起的,这是一句废话,信息量为0。
全国两会:十四届全国人大一次会议【3月5日】;全国政协十四届一次会议【3月4日】”,从直觉上来看,这句话具有很大的信息量。因为中国队进入世界杯的不确定性因素很大,而这句话消除了进入世界杯的不确定性,所以按照定义,这句话的信息量很大。
根据上述可总结如下:信息量的大小与信息发生的概率成反比。概率越大,信息量越小。概率越小,信息量越大。
设某一事件发生的概率为P(x),其信息量表示为:
I ( x ) = − l o g ( p ( x ) ) I(x) = -log^{(p(x))} I(x)=log(p(x))
其中I(x)表示信息量,这里 l o g ( p ( x ) ) log^{(p(x))} log(p(x))表示以e为底的自然对数。

信息熵

信息熵也被称为,用来表示所有信息量的期望。

期望是试验中每次可能结果的概率乘以其结果的总和。

所以信息量的熵可表示为:(这里的X是一个离散型随机变量)
H ( X ) = − ∑ P ( x i ) l o g ( P ( X i ) ) ( X = x 1 , x 2 , . . . , x n ) H(X) = - \sum P(x_i)log^{(P(X_i))} \quad (X =x_1,x_2,...,x_n) H(X)=P(xi)log(P(Xi))(X=x1,x2,...,xn)
使用明天的天气概率来计算其信息熵:

序号事件概率P信息量
1明天是晴天0.5 − l o g 0.5 -log^{0.5} log0.5
2明天是雨天0.25 − l o g 0.25 -log^{0.25} log0.25
3明天是多云0.25 − l o g 0.25 -log^{0.25} log0.25

即:
H ( X ) = − ( 0.5 ∗ l o g 0.5 + 0.25 ∗ l o g 0.25 + 0.25 ∗ l o g 0.25 ) H(X) = -(0.5 * log^{0.5} + 0.25*log^{0.25} + 0.25 * log^{0.25}) H(X)=(0.5log0.5+0.25log0.25+0.25log0.25)

相对熵(KL散度)

如果对于同一个随机变量X有两个单独的概率分布p(x)q(x),则我们可以使用KL散度来衡量这两个概率分布之间的差异。
D K L ( p ∣ ∣ q ) = ∑ i = 1 n p ( x i ) l o g ( p ( x i ) q ( x i ) ) D_{KL}(p||q) = \sum_{i=1}^n p(x_i)log^{(p(x_i)\over q(x_i))} DKL(p∣∣q)=i=1np(xi)logq(xi))(p(xi)
在机器学习中,常常使用 p(x) 表示样本的真实分布,q(x) 表示来表示模型所预测的分布,
KL散度越小,表示p(x)q( x ) 的分布更加接近,可以通过反复训练q(x) 来使 q(x) 的分布逼近p(x)

交叉熵

交叉熵的推导:

首先将KL散度公式拆开:
D K L ( p ∣ ∣ q ) = ∑ i = 1 n p ( x i ) l o g ( p ( x i ) q ( x i ) ) D_{KL}(p||q) = \sum_{i=1}^n p(x_i)log^{(p(x_i)\over q(x_i))} DKL(p∣∣q)=i=1np(xi)logq(xi))(p(xi)
= ∑ i = 1 n p ( x i ) ( l o g p ( x i ) − l o g q ( x i ) ) = \sum_{i=1}^n p(x_i) (log^{p(x_i)} - log^{q(x_i)} ) =i=1np(xi)logp(xi)logq(xi)
= ∑ i = 1 n p ( x i ) l o g p ( x i ) − ∑ i = 1 n p ( x i ) l o g q ( x i ) ) = \sum_{i=1}^n p(x_i) log^{p(x_i)} - \sum_{i=1}^n p(x_i) log^{q(x_i)} ) =i=1np(xi)logp(xi)i=1np(xi)logq(xi)
= − H ( x ) + [ − ∑ i = 1 n p ( x i ) l o g q ( x i ) ) ] = -H(x) + [- \sum_{i=1}^n p(x_i) log^{q(x_i)} )] =H(x)+[i=1np(xi)logq(xi)]
前者H(X), 即H(p(x))表示信息熵,后者即为交叉熵,KL散度 = 交叉熵 - 信息熵
所以,
交叉熵公式表示为:
H ( x ) = − ∑ i = 1 n p ( x i ) l o g q ( x i ) H(x) = - \sum_{i=1}^n p(x_i) log^{q(x_i)} H(x)=i=1np(xi)logq(xi)
在机器学习训练网络时,输入数据与标签常常已经确定,那么真实概率分布p(x) 也就确定下来了,所以信息熵在这里就是一个常量。由于KL散度的值表示真实概率分布p(x)与预测概率分布q(x) 之间的差异,值越小表示预测的结果越好,所以需要最小化KL散度,而交叉熵等于KL散度加上一个常量(信息熵),且公式相比KL散度更加容易计算,所以在机器学习中常常使用交叉熵损失函数来计算loss。

交叉熵在多分类问题中的应用

在线性回归问题中,常常使用MSE(Mean Squared Error)作为loss函数,而在分类问题中常常使用交叉熵作为loss函数。


http://www.ppmy.cn/news/56945.html

相关文章

【Python】操作MySQL

一、Python 操作 Mysql的方式 Python 操作 Mysql 主要包含下面 3 种方式: Python-MySql Python-MySql 由 C 语法打造,接口精炼,性能最棒;但是由于环境依赖多,安装复杂,已停止更新,仅支持 Python…

密码学【java语言】初探究

文章目录 前言一 密码学1.1 古典密码学1.1.1 替换法1.1.2 移位法1.1.3 古典密码破解方式 二 近代密码学2.1 现代密码学2.1.1 散列函数2.1.2 对称密码2.1.3 非对称密码 二 凯撒加密的实践2.1 基础知识:ASCII编码2.2 ascii编码演示2.3 凯撒加密和解密实践2.4 频率分析…

多项式加法(用 C 语言实现)

目录 一、多项式的初始化 二、多项式的创建 三、多项式的加法 四、多项式的输出 五、清除链表 六、主函数 用链表实现多项式时,每个链表节点存储多项式中的一个非零项,包括系数(coef)和指数(exp)两个…

权限提升:漏洞探针.(Linux系统)

权限提升:漏洞探针. 权限提升简称提权,由于操作系统都是多用户操作系统,用户之间都有权限控制,比如通过 Web 漏洞拿到的是 Web 进程的权限,往往 Web 服务都是以一个权限很低的账号启动的,因此通过 Webshel…

python-使用Qchart总结3-绘制曲线图

1.将画好的图表关联 解释说明图 2.新建一个文件画曲线图,并关联到UI的py文件上,上代码 import sys from PyQt5.Qt import * from PyQt5.QtChart import QChartView, QChart, QValueAxis, QSplineSeries from PyQt5.QtGui import QPainter, QColor, QFon…

5月2日第壹简报,星期二,农历三月十三

5月2日第壹简报,星期二,农历三月十三坚持阅读,静待花开1. “港车北上”政策公布:6月1日起接受申请,7月1日起可驶入广东,将惠及45万香港车主。2. 全球女性第一人!董红娟登顶全部14座8000米级高峰…

【代码随想录】刷题Day14

递归实现的一些理解 1.如果是链表的遍历其实不需要怎么思考;无非就是先定参数然后考虑是先操作后遍历还是先走到底再操作。 包括我之前在写链表的节点删除其实核心思路就是由于链表前面删除后面找不到的原理,以至于我们需要走到链表的底部再进行操作。 2…

js 特殊对象 - String对象

1.概述 字符串本质是字符数组,所以也是对象,但是单独数据类型为了好用列一个string 在JS中为我们提供了三个包装类,通过这三个包装类可以将基本数据类型的数据转换为对象 String():可以将基本数据类型字符串转换为String对象 Num…