PyTorch 的 nn.NLLLoss:负对数似然损失全解析

news/2025/3/6 3:58:41/

PyTorch 的 nn.NLLLoss:负对数似然损失全解析

在 PyTorch 的损失函数家族中,nn.NLLLoss(Negative Log Likelihood Loss,负对数似然损失)是一个不太起眼但非常重要的成员。它经常跟 LogSoftmax 搭配出现,尤其在分类任务中扮演关键角色。今天我们就来聊聊 nn.NLLLoss 的数学原理、使用方法,以及它适用的场景,带你彻底搞懂这个损失函数。

1. 什么是负对数似然损失?

先从名字拆解:

  • 似然(Likelihood):在统计学中,似然表示“给定模型参数时,观察到数据的概率”。对数似然(Log Likelihood)是它的对数形式,常用于简化计算。
  • 负对数似然(Negative Log Likelihood, NLL):把对数似然取负数,作为损失函数,目标是最小化它。

在机器学习中,负对数似然损失通常用来衡量模型预测的概率分布与真实标签的差距,尤其是在分类任务中。

数学公式

假设我们有一个多分类任务,有 ( C C C ) 个类别。对于一个样本:

  • ( y ^ \hat{y} y^ ) 是模型输出的概率分布,比如经过 Softmax 或 LogSoftmax 处理后的结果。
  • ( y y y ) 是真实类别,用索引表示(比如 2 表示第 2 类)。

nn.NLLLoss 的公式是:

NLL = − 1 N ∑ i = 1 N log ⁡ ( y ^ i , y i ) \text{NLL} = -\frac{1}{N} \sum_{i=1}^{N} \log(\hat{y}_{i, y_i}) NLL=N1i=1Nlog(y^i,yi)

  • ( N N N ):样本数量(batch size)。
  • ( y i y_i yi ):第 ( i i i ) 个样本的真实类别索引。
  • ( y ^ i , y i \hat{y}_{i, y_i} y^i,yi ):第 ( i i i ) 个样本在真实类别 ( y i y_i yi ) 上的预测概率(对数值)。

简单来说,nn.NLLLoss 取预测概率的对数(已经由 LogSoftmax 计算好),然后取负号,只关心正确类别的概率值。

2. 为什么搭配 LogSoftmax

你可能会注意到,nn.NLLLoss 的文档里总是提到“通常与 LogSoftmax 搭配使用”。这是为什么?

  • 模型输出:神经网络的最后一层通常输出未归一化的 logits(比如 [1.0, 2.0, 0.5]),而不是概率。
  • Softmax:将 logits 转为概率分布,比如 [0.2, 0.5, 0.3],满足 ( ∑ y ^ = 1 \sum \hat{y} = 1 y^=1)。公式是:
    y ^ j = e z j ∑ k = 1 C e z k \hat{y}_j = \frac{e^{z_j}}{\sum_{k=1}^{C} e^{z_k}} y^j=k=1Cezkezj
  • LogSoftmax:在 Softmax 基础上取对数,输出的是对数概率,比如 [-1.6, -0.7, -1.2]。公式是:
    log ⁡ ( y ^ j ) = z j − log ⁡ ( ∑ k = 1 C e z k ) \log(\hat{y}_j) = z_j - \log(\sum_{k=1}^{C} e^{z_k}) log(y^j)=zjlog(k=1Cezk)

nn.NLLLoss 要求输入是对数概率(log probabilities),而不是原始概率。所以:

  • 如果你直接给它 Softmax 后的概率,会出错,因为它期待的是 ( log ⁡ ( y ^ ) \log(\hat{y}) log(y^))。
  • LogSoftmax 处理后,输入正好符合要求,计算时直接取负号即可。
3. 代码使用示例

我们来看一个简单的例子,展示 nn.NLLLossLogSoftmax 的搭配:

python">import torch
import torch.nn as nn# 假设一个 3 分类任务,batch_size = 2
logits = torch.tensor([[1.0, 2.0, 0.5], [0.1, 0.5, 2.0]])  # 原始 logits
target = torch.tensor([1, 2])  # 真实类别索引,0~2# 定义 LogSoftmax 和 NLLLoss
log_softmax = nn.LogSoftmax(dim=1)  # dim=1 表示在类别维度上归一化
loss_fn = nn.NLLLoss()# 计算损失
log_probs = log_softmax(logits)  # 先转为对数概率
loss = loss_fn(log_probs, target)
print("NLL Loss:", loss.item())

运行过程

  1. logits[batch_size, num_classes] 的张量,表示每个样本在每个类别上的得分。
  2. nn.LogSoftmax 把 logits 转为对数概率,比如 [[-1.9, -0.9, -2.4], [-2.3, -1.9, -0.4]]
  3. nn.NLLLoss 提取每个样本在真实类别上的对数概率(比如第一个样本取 -0.9,第二个取 -0.4),取负并平均。

输出可能是 1.15,具体值取决于输入。

4. 与 nn.CrossEntropyLoss 的关系

你可能听说过 nn.CrossEntropyLoss,它也很常见。实际上:

  • nn.CrossEntropyLoss = LogSoftmax + nn.NLLLoss
    PyTorch 把这两步合二为一,直接接受 logits 作为输入,内部自动完成 LogSoftmax 和 NLL 计算。具体过程可以参考笔者的另一篇博客:Pytorch为什么 nn.CrossEntropyLoss = LogSoftmax + nn.NLLLoss?

代码对比:

python"># 用 nn.CrossEntropyLoss
ce_loss_fn = nn.CrossEntropyLoss()
ce_loss = ce_loss_fn(logits, target)
print("CrossEntropyLoss:", ce_loss.item())  # 与上面结果相同
  • 区别
    • nn.NLLLoss:输入是对数概率,需手动加 LogSoftmax
    • nn.CrossEntropyLoss:输入是 logits,自动处理。
5. 使用场景

nn.NLLLoss 适用于以下场景:

  • 多分类任务:比如图像分类(CIFAR-10 的 10 类)、文本分类。
  • 需要分离 Softmax 的情况
    • 你想在模型里显式控制 LogSoftmax 的位置,而不是交给损失函数。
    • 调试时单独检查对数概率的值。
  • 概率输出的模型:如果你的模型已经输出对数概率(比如某些预训练模型),直接用 nn.NLLLoss 更高效。

典型例子

  • 一个简单的 CNN 分类器:
    python">class SimpleCNN(nn.Module):def __init__(self):super().__init__()self.conv = nn.Conv2d(1, 16, 3)self.fc = nn.Linear(16 * 26 * 26, 10)  # 假设 28x28 输入self.log_softmax = nn.LogSoftmax(dim=1)def forward(self, x):x = self.conv(x)x = x.view(x.size(0), -1)x = self.fc(x)return self.log_softmax(x)model = SimpleCNN()
    loss_fn = nn.NLLLoss()
    
    这里模型输出对数概率,搭配 nn.NLLLoss 计算损失。
6. 注意事项
  • 输入形状
    • 输入:[batch_size, num_classes](对数概率)。
    • 目标:[batch_size](类别索引)。
  • 目标类型:必须是整数(long 类型),不能是 one-hot 或浮点数。
  • 数值稳定性LogSoftmax 比单独的 Softmax + log 更稳定,因为它避免了溢出问题。
7. 小结:nn.NLLLoss 的核心
  • 数学原理:计算正确类别对数概率的负值,最小化它等价于最大化似然。
  • 使用方式:搭配 LogSoftmax,输入对数概率,输出标量损失。
  • 场景:多分类任务,尤其是需要显式控制概率计算时。
  • CrossEntropyLoss 的关系:前者是后者的组成部分,功能更模块化。

nn.NLLLoss 就像一个“半成品”,需要你自己搭配 LogSoftmax,但这也给了你更多灵活性。相比直接用 nn.CrossEntropyLoss,它更适合喜欢拆解步骤或调试模型的开发者。

8. 调试小技巧
  • 检查输入:打印 log_probs 确保是对数概率(负值)。
  • 验证目标:确保 target 是整数,且范围在 [0, num_classes-1]
  • 对比结果:用 nn.CrossEntropyLoss 验证是否一致。

希望这篇博客让你对 nn.NLLLoss 有了全面认识!

后记

2025年2月28日18点59分于上海,在Grok3大模型辅助下完成。


http://www.ppmy.cn/news/1576974.html

相关文章

读写分离架构下的一致性挑战

读写分离架构下的一致性挑战 什么是读写分离架构读写分离架构的一致性挑战主从复制延迟事务不一致 主从切换导致的数据丢失跨表/跨库操作的一致性问题缓存与数据库的一致性问题查询路由策略不当导致的问题全局二级索引的一致性问题历史查询与实时数据的一致性分布式锁与读写分离…

Transformer 代码剖析9 - 解码器模块Decoder (pytorch实现)

一、模块架构全景图 1.1 核心功能定位 Transformer解码器是序列生成任务的核心组件,负责根据编码器输出和已生成序列预测下一个目标符号。其独特的三级注意力机制架构使其在机器翻译、文本生成等任务中表现出色。下面是解码器在Transformer架构中的定位示意图&…

kettle插件-git/svn版本管理插件

场景:大家都知道我们平时使用spoon客户端的时候时无法直接使用git的,给我们团队协作带来了一些小问题,需要我们本机单独安装git客户端进行手动上传trans或者job。 我们团队成员倪老师开发了一款kettle的git插件,帮我们解决了这个…

浅浅初识AI、AI大模型、AGI

前记:这里只是简单了解,后面有时间会专门来扩展和深入。 当前,人工智能(AI)及其细分领域(如AI算法工程师、自然语言处理NLP、通用人工智能AGI)的就业前景呈现高速增长态势,市场需求…

【Flink银行反欺诈系统设计方案】2.风控规则表设计与Flink CEP结合

Flink CEP与风控规则表结合的银行反欺诈系统 1. 实现思路 规则加载: 使用Flink的JDBC Source定期从risk_rules表中加载规则。 将规则广播到所有Flink任务中。 动态模式构建: 根据规则表中的条件动态构建Flink CEP的模式。 将交易数据流与规则广播…

C语言机试编程题

编写版本:vc2022 目录 1.求最大/小值 2.求一个三位数abc,使a的阶乘b的阶乘c的阶乘abc 3.求2/1,3/2,5/3,8/5,13/8,21/13,的前20项和 4.求阶乘 5.求10-1000之间所有数字之和为5的…

Github 2025-03-04 Python开源项目日报 Top10

根据Github Trendings的统计,今日(2025-03-04统计)共有10个项目上榜。根据开发语言中项目的数量,汇总情况如下: 开发语言项目数量Python项目10Svelte项目1JavaScript项目1 系统设计指南 创建周期:2507 天开发语言:P…

自然语言处理:朴素贝叶斯

介绍 大家好,博主又来和大家分享自然语言处理领域的知识了。按照博主的分享规划,本次分享的核心主题本应是自然语言处理中的文本分类。然而,在对分享内容进行细致梳理时,我察觉到其中包含几个至关重要的知识点,即朴素…