【AI面试】CrossEntropy Loss 、Balanced Cross Entropy、 Dice Loss 和 Focal Loss 横评对比

news/2025/1/15 23:03:57/

样本不均衡问题一直是深度学习领域一个不可忽略的问题,常说的长尾效应,说的就是这个问题。一类占据了主导地位,导致其他类无论怎么优化,都不能好转。

无论是纯纯的分类任务,还是稍微复杂一些的目标检测任务和分割任务,这种不均衡都一直存在,表现为三种形式:

  1. 类与类之间的数量不均衡;
  2. 前景与背景之间数量不均衡;
  3. 简单样本和难样本之间的不均衡。

如何在计算损失的时候,得到的最终的loss值,对各个数量不均衡的部分,回归的效果更佳均衡,让模型在更新的时候,不会侧重于某一面,而忽略了另一面的问题?值得探讨

本文就对类似这种不均衡的问题展开讨论,这也是在工作中常常会遇到的问题,同时,也是面试官最喜欢问到的问题。


与此同时,在UNet分割任务中,把目标分割任务的mask,转化为对像素点的分类任务。所以在计算损失的时候,论文里面的损失函数采用的就是交叉熵损失函数。

在后续的损失改进中,多引入dice lossfocal loss。我们就从交叉熵损失函数开始,探讨下它为什么可以应用在分割任务中。

一、CE(CrossEntropy)交叉熵损失

在pytorch中nn.CrossEntropyLoss()为交叉熵损失函数,用于解决多分类问题,也可用于解决二分类问题。

在使用nn.CrossEntropyLoss()其内部会自动加上Sofrmax层,对输入pd进行归一化操作。

nn.CrossEntropyLoss()的pytorch文档计算公式,如下:

ce
其中:

  • C:number of classes
  • N:batch size
  • wc:对应class的weight权重

PytorchCrossEntropyLoss()函数的主要是将softmax+log+NLLLoss合并到一块得到的结果。(下面会给出代码验证,参考:https://blog.csdn.net/qq_39506862/article/details/126625515)

  1. Softmax后的数值都在0~1之间,所以ln之后值域是负无穷到0;
    • log中0-1的输入,输出是0-正无穷;
    • 乘以w加负号,值域是负无穷到0。
  2. 然后将Softmax之后的结果取log,将乘法改成加法,减少计算量,同时保障函数的单调性;
  3. NLLLoss的结果就是把上面的输出与Label对应的那个值拿出来,去掉负号,再求均值。

log以e为底,nn.CrossEntropyLoss(pd, gt),有两个参数:

  • 第一个参数:pd 是网络的最后一层的输出,其shape为[batchsize, class](函数要求第一个参数,也就是最后一层的输出为二维数据,每个向量中的值为不同种类的概率值)
  • 第二个参数:gt 是传入的标签,也就是某个类别的索引值,在上面公式没有参与计算。batch_size如果是1,那么就只有一个数字,0,1或者2,表示的就是此时这个样本对应的真实类别,如果为batch_size是2,那么就有两个数字,例如(0,1),分别表示这两条样本对应的真实类别。

对于(minibatch,C) 的情况:例如,输入(也就是最后一层的输出)为pd=[[4,8,3]]shape=(1,3),batchsize=1,class=3

第一步, 我们首先计算一下各个类别的损失,这里的类似是3,默认各类的权重都是1,对类别0、1、2类别对应的损失分别是:

    loss(pd, 0) # class=0=-pd[0]+log(exp(pd[0])+exp(pd[1])+exp(pd[2]))=-4+log(exp(4)+exp(8)+exp(3))=-4+8.0247=4.0247loss(pd, 1) # class=1=-pd[1]+log(exp(pd[0])+exp(x[1])+exp(x[2]))=-8+log(exp(4)+exp(8)+exp(3))=-8+8.0247=0.0247loss(pd, 2) # class=2=-pd[2]+log(exp(pd[0])+exp(pd[1])+exp(pd[2]))=-3+log(exp(4)+exp(8)+exp(3))=-3+8.0247=5.0247

第二步, 用第二个参数(一维数据)也就是根据类别对应的索引值,在上面结果中选择对应的损失,例如若网络的输出为x=[[4,8,3]],而对应的标签为1,则得到损失函数loss=0.0247,其他类不累积计算。

one-hot里面,也就是为one的计算损失,其他为0的,不计算损失。

换言之:

  1. 对于预测输出维度为(batch, C)的情况,CrossEntropyLoss在计算损失时只会对应gt类别的那个值进行计算,其他类别的预测输出不参与计算。
  2. PyTorch的实现中,这个过程被封装在了nn.functional.cross_entropy函数中。
  3. 具体地说,对于batch内的每个样本cross_entropy会根据gt标签的值在预测输出的张量中找到对应的位置,然后计算该位置上的预测输出和对应的对数概率,最终将这些计算结果取平均得到整个batch的损失值。

举个例子:

  • 对应一个minibatch=8, classes=15的多分类任务,传入nn.CrossEntropyLoss(pd, gt),计算损失。
  • 其中:pd的shape是[8,15],gt的shape是[8],也就是一个minibatch的各个图对应的类别,如果传入的gt是one-hot形式,那么one-hot形式的shape是[8,15],需要经过torch.max(label, 1)[1]处理,意思是说在label的维度为1上做max操作。
  • 返回tensors = (max, max_indices),取max_indices,得到对应的类别数。

如下案例:

>>> import torch
>>> a = torch.eye(8,15)
>>> a
tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],[0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.]])
>>> a.shape
torch.Size([8, 15])
>>> gt = torch.max(a, 1)[1]
>>> gt
tensor([0, 1, 2, 3, 4, 5, 6, 7])
>>> gt.shape
torch.Size([8])

再把前面手算的loss案例,拿过来计算下,如下:

>>> pd=torch.tensor([[4.0,8.0,3.0]])
>>> gt=torch.tensor([1])
>>> loss = c(pd, gt.long())
>>> loss
tensor(0.0247)>>> gt=torch.tensor([2])
>>> loss = c(pd, gt.long())
>>> loss
tensor(5.0247)

可以看到,

  • gt为1的,对应的loss是0.0247;
  • gt为2的,对应的loss是5.0247;
  • 与我们前面手算的一致,也符合只对one hot中为1的部分计算损失,其他不管。

对于 (minibatch,C,d1,d2,···,dk)的情况,其中k≥1的 K-dimensional case,同样适用。这里就是UNet分割网络计算损失时候的重点。前面说那么多也就是为了引出这个部分。案例如下:

>>> gt=torch.rand(8, 24, 24)
>>> pd=torch.rand(8, 5, 24, 24)
>>> loss = c(pd, gt.long())
>>> loss
tensor(1.6401)
>>> 

换言之:

  1. 对于预测输出pd的维度为(minibatch,C,d1,d2)的情况,输入的gt的shape应该为(minibatch,d1,d2),即每个样本的标签是一个d1*d2的二维矩阵
  2. 其中,每个像素的值表示对应位置的像素属于哪个类别,通常使用整数表示类别的编号。
  3. 例如,如果有3个类别,那么像素可能被标记为0、1或2。在计算交叉熵损失时,会将预测输出pd和对应的gt进行比较,计算每个像素的分类损失,然后对所有像素的损失求平均得到最终的损失值。
  4. 对于输入的预测张量,通常会在C维度上进行softmax操作,使得每个通道(类别)的输出值都在[0,1]范围内,并且所有通道的输出值之和为1
  5. 这样做的目的是将预测结果转换成概率分布,方便计算交叉熵损失。
  6. 在PyTorch中,torch.nn.CrossEntropyLoss()函数会自动将输入进行softmax操作。

二、Balanced Cross Entropy

Focal Loss for Dense Object Detection里面,引入了Balanced Cross Entropy的描述。在上文中pytorchCross Entropy里面的weight就是自己定义的类别权重,这是一个超参数。
0

其中:

  1. for class 1 ,α 属于 [0, 1]
  2. for class -1,weight=1- α
  3. α 设定是class 数量的倒数,数量越多,α 越小

三、Dice loss

Dice系数中的"Dice"实际上是一位科学家名字的缩写,其全名是Sørensen–Dice coefficient,常被称为Dice similarity coefficient或者F1 score。它由植物学家Thorvald SørensenLee Raymond Dice独立研制,分别于 1948 年和 1945 年发表。

Dice系数是一种常见的相似度计算方法,主要用于计算两个集合的相似度。在 Dice Loss 中,用 Dice 系数来计算预测结果和真实标签的相似度,因此得名 Dice Loss

dice coefficient定义如下:
1

如果看作是对像素点类别的分类任务,也可以写成:
2

于是,dice loss就可以表示为:
3

Dice系数的中文名称为“Dice相似系数”或“Dice相似度”,因此 Dice Loss 也可以称为“Dice相似度损失”或“Dice相似系数损失”。

对于输出只有一个类别的分割任务,输出的channel=1。在计算dice loss的时候,需要先使用sigmoid将预测结果,归一化到0-1的范围。具体的代码如下所示:

def dice_loss(pred, target):smooth = 1.iflat = pred.contiguous().view(-1)tflat = target.contiguous().view(-1)intersection = (iflat * tflat).sum()return 1 - ((2. * intersection + smooth) / (iflat.sum() + tflat.sum() + smooth))# example usage:
batch_size, num_classes, height, width = 2, 1, 256, 256
pred = torch.randn(batch_size, num_classes, height, width)
# apply sigmoid to predicted output
pred = torch.sigmoid(pred)
print(pred)
target = torch.randint(0, 2, size=(batch_size, height, width)).float()
loss = dice_loss(pred, target)
print(loss)

对于多类别的loss,定义如下:

def dice_loss_2(pred, target, weight=None, num_classes=5):smooth = 1.# 各个类别分别计算损失dice = torch.zeros(num_classes, dtype=torch.float32, device=pred.device)for c in range(num_classes):# 取出对应类的pred和targetp_c = pred[:, c].contiguous().view(-1)print(p_c)# target 标记是0、1、2、3、4,表示的是背景和各个类,下面操作类似于one-hott_c = (target == c).float().view(-1)intersection = torch.sum(p_c * t_c)union = torch.sum(p_c) + torch.sum(t_c)dice[c] = (2. * intersection + smooth) / (union + smooth)# apply class weightsif weight is not None:dice = dice * weight.view(1, -1)return 1 - dice.mean()batch_size, num_classes, height, width = 2, 5, 32, 32
pred = torch.randn(batch_size, num_classes, height, width)
pred = torch.softmax(pred, dim=1)   # 在num_classes维度上,进行softmax操作
target = torch.randint(0, num_classes, (batch_size, height, width))print(target)
loss = dice_loss_2(pred, target)
print(loss)

打印targetp_c,可以看到下面的结果,其中红框区域内的和是1,是pred = torch.softmax(pred, dim=1) 的操作结果:

2

这个时候我就产生了一个疑惑:在沿着num_classes的维度上,进行softmax操作后,带入dice loss公式时候,是否需要将最大值置为1,其他置为0后再计算吗?还是直接0-1范围内的数值,直接和target的0 or 1进行计算呢?像上面的函数定义一样。为什么?

chatGPT给出了答案,如下:

在这个函数中,已经对预测值进行了softmax操作,使得每个像素点在每个类别上的预测值都在0到1之间,表示该像素点属于每个类别的概率。因此,在计算dice loss时,无需将最大值置为1,其他置为0

该函数中,对于每个类别c,通过计算交集和并集,得到该类别的dice系数。然后将所有类别的dice系数取平均,并用1减去平均值,得到最终的dice loss

因此,对于每个像素点,其在所有类别上的预测值都会对最终的dice loss产生影响,而不仅仅是最大值

我继续追问:意思是说,如果取最大值置为0,其他置为0,那么就会导致只会在为1的维度上进行计算loss,而其他的损失未参与计算,是这样一个意思吗?

chatGPT回答:是的,如果将最大值置为1,其他置为0,那么只有在最大值对应的类别上计算的dice系数才会对最终的dice loss产生影响,而其他类别的dice系数则不会参与计算。这样会导致损失函数无法充分利用所有类别的信息,可能会影响模型的训练效果。因此,在计算dice loss时,应该将每个像素点在所有类别上的预测值都考虑进去。

更多内容,可以参考这里:图像分割必备知识点 | Dice损失 理论+代码

四、Focal Loss

在分类任务中,容易分类的negative构成了大部分的损失,并主导了梯度。但是这部分并不是最重要的,并且是比较容易区分的。Balanced Cross Entropy中的α balances the importance of positive/negative examples, it does not differentiate between easy/hard examples

进而,我们提出了一种loss function去降低weight of easy example,让训练集中到hard negative

Focal loss是最初由何恺明在论文:Focal Loss for Dense Object Detection论文中提出来的,最初用于图像领域解决数据不平衡造成的模型性能问题。

在论文中:

focal loss
翻译过来就是:

直观地说,这个比例因子可以在训练过程中,自动降低简单示例的权重,并迅速将模型集中在困难示例上。实验表明,我们提出的Focal Loss使我们能够训练一个高精度的单阶段检测器,其性能明显优于使用抽样启发式或硬示例挖掘训练的替代方案,这是先前用于训练单阶段检测器的最先进技术。最后,我们注意到focal loss的确切形式并不重要,并且我们展示了其他实例可以实现类似的结果

Focal loss是基于二分类交叉熵CE的。它是一个动态缩放的交叉熵损失,通过一个动态缩放因子,可以动态降低训练过程中易区分样本的权重,从而将重心快速聚焦(focal)在那些难区分的样本(有可能是正样本,也有可能是负样本,但都是对训练网络有帮助的样本)。

12

论文中提到,我们注意到焦损的两个特性(We note two properties of the focal loss.)

  1. 对于预测结果,与真实的标注,会存在两种情况:
    • 当样本被误分类时,pt较小,趋于0,调制因子接近1,权重不变,loss不受影响;
    • 当样本被正确分类时,pt趋于1,调制因子接近0,对于分类较好的,降低权重。(When an example is misclassified and pt is small, the modulating factor is near 1 and the loss is unaffected.As pt --> 1, the factor goes to 0 and the loss for well-classified examples is down-weighted)
    • 直观讲,就是计算损失时,降低正确分类的loss在总loss中的权重,预测差的保持loss,这样变相的提高了错误分类样本在loss中的权重。
  2. 参数γ降低了简单样本的降低权重操作( The focusing parameter γ smoothly adjusts the rate at which easy examples are down weighted. )

γ = 0时,Focal Loss等于CE,随着γ的增大,调制因子的作用也随之增大(When γ = 0, FL is equivalent to CE, and as γ is increased the effect of the modulating factor is likewise increased )

对这块的解释,网上资料太多了,关键就在于你该如何理解。我搜罗了一些好理解的,发散开的,建议参考这里:何恺明大神的「Focal Loss」,如何更好地理解?

他从硬截断,去除,不参与total loss的计算:正样本的预测值大于 0.5 的,或者负样本的预测值小于 0.5 的,我都不更新了,把注意力集中在预测不准的那些样本,当然这个阈值可以调整。

导致模型只知道更新难的,错误的样本损失,不知道保持正确的分类。等等,相信上面的文章能给你启发。但是最后还是建议去看看原论文,介绍的很详细。

五、总结

除了上面介绍的几个分类常用到的损失函数,还一直会有更加优秀的loss出现。本篇只是开了一个头,并且上述介绍的损失函数,在各个深度学习框架内基本都包含了这些函数,也无需我们自己手写。

在后续如果有更优秀的loss出现,我也会把它归类到这里。喜欢就点赞收藏不迷路。

在最后的最后,如果你也在学习新的内容,一定要去翻看原论文。我一般把参考的论文链接,都放到了文章里面。在原论文里面,因为是第一次出现,所以一般描述的都是比较的详尽,包括:

  • 为什么会有这个概念的出现?
  • 他解决了之前存在的什么问题?
  • 有什么优势和具体如何实现的?

都非常的详尽,是一个学习的第一手资料。希望对你有用,其余其中不正确和不详尽的地方,欢迎评论区留言,一起讨论,我也会学习后,分享补充进来。


http://www.ppmy.cn/news/77984.html

相关文章

对CommonJS、AMD、CMD、ES Module的理解

CommonJS 常用于:服务器端,node,webpack 特点:同步/运行时加载,磁盘读取速度快 语法: // 1. 导出:通过module.exports或exports来暴露模块 module.exports { attr1, attr2 } ex…

npm init和npm create、npm create vite什么意思

npm init 和npm create 之前在用npm时,一直都是npm init来初始化一个项目并生成package.json文件。 但是,今天在看vite的官方文档时,vite上说创建一个vite项目的命令是 npm create vitelatest之前一直没有用过npm create这个命令&#xff…

亚马逊关联视频:如何成为影响者视频行业的重要一环?

亚马逊作为全球最大的电商平台之一,一直在探索如何让更多的消费者通过视频了解和购买自己的产品。而随着社交媒体和视频内容的兴起,越来越多的品牌和卖家开始将亚马逊关联视频作为营销和推广的重要手段,同时也为消费者提供了更加生动、直观的…

pdf如何删除其中一页?不妨试试这些办法

PDF格式是一种非常常见的文档格式,它可以在各种系统和设备上使用,而且无论在哪里打开,都可以保持格式的一致性。有时候,我们需要删除PDF文档中的一页,无论是为了更改文档的结构,还是为了删除错误的信息。在…

MATLAB PCA及其实现方法

PCA(Principal Component Analysis)是一种常用的数据降维方法,它可以将高维数据转换为低维数据,同时保留原始数据中的主要信息。在实际应用中,我们通常使用matlab来实现PCA算法。 本文将介绍matlab中的PCA实现方法&am…

用写代码的方式画图-试下PlantUML吧 | 京东云技术团队

1 序言 所谓一图胜千言,大家平日在工作中编写文档时,往往都需要画各种图来表达中心思想,比如流程图、时序图、UML 图,很多人选择使用 Axure 、PrecessOn、Diagrams(darw.io)、XMind、Visio、yEd、Lucidcha…

Python 近场通信:使用 NFC (Near Field Communication) 和 nfcpy

目录 NFC 的基本原理 nfcpy 模块介绍 Python NFC 读取例子 结语 近场通信 (NFC) 是一种在物理距离 4 厘米以内的设备之间进行简单安全的双向交互的技术。Python 提供了一种实现 NFC 的模块,名为 nfcpy。本文将对 NFC 的原理进行简要介绍,并通过 Pyth…

汇聚支付APP+技术方案介绍

一、时序图 商户使用汇聚支付的 APP方案,需要两个步骤: 步骤一: 请求汇聚支付的支付接口,https://www.joinpay.com/trade/uniPayApi.action 获取返回的关键参数 rc_Result。 步骤二: 商户 APP 按照微信的 SDK 规范使…