《深度学习》PyTorch 常用损失函数原理、用法解析

server/2024/9/23 9:37:01/

目录

一、常用损失函数

1、CrossEntropyLoss(交叉熵损失)

        1)原理

        2)流程

        3)用法示例

2、L1Loss(L1损失/平均绝对误差)

        1)原理

        2)用法示例

3、NLLLoss(负对数似然损失)

        1)原理

        2)用法示例

4、 MSELoss(均方误差损失)

        1)定义

        2)用法示例

5. BCELoss(二元交叉熵损失)

        1)定义

        2)用法示例

二、总结常用损失函数

        1、nn.CrossEntropyLoss:交叉熵损失函数

        2、nn.MSELoss:均方误差损失函数

        3、nn.L1Loss:平均绝对误差损失函数

        4、nn.BCELoss:二元交叉熵损失函数

        5、nn.NLLLoss:负对数似然损失函数


一、常用损失函数

1、CrossEntropyLoss(交叉熵损失)

        1)原理

                交叉熵损失是一种常用于分类问题的损失函数,它衡量的是模型输出的概率分布与真实标签分布之间的差异

                在多分类问题中,模型会输出每个类别的预测概率。交叉熵损失通过计算真实标签对应类别的负对数概率评估模型的性能。在实际应用中,nn.CrossEntropyLoss内部会对logits(即未经softmax的原始输出)应用softmax函数,将其转换为概率分布,然后计算交叉熵。

                例如:

                        假设有一个多类别分类任务,共有C个类别。对于每个样本,模型会输出一个包含C个元素的向量,其中每个元素表示该样本属于对应类别的概率。而真实标签是一个C维的向量,其中只有一个元素为1,其余元素均为0,表示样本的真实类别。

        2)流程

                首先,将模型输出的向量通过softmax函数进行归一化,将原始的概率值转换为概率分布。即对模型输出的每个元素进行指数运算,然后对所有元素求和,最后将每个元素除以总和,得到归一化后的概率分布。

                然后,将归一化后的概率分布与真实标签进行比较,计算两者之间的差异。交叉熵损失函数的计算公式为: -sum(y * log(p))  ,其中y是真实标签的概率分布,p是模型输出的归一化后的概率分布。该公式表示真实标签的概率分布与模型输出的归一化后的概率分布之间的交叉熵。

                最后,将每个样本的交叉熵损失值进行求和或平均,得到整个批次的损失值。

       

        3)用法示例
import torch  
import torch.nn as nn  # 假设有一个模型输出的logits和一个真实的标签  
logits = torch.randn(10, 5, requires_grad=True)  # 10个样本,5个类别  
labels = torch.randint(0, 5, (10,))  # 真实标签,每个样本对应一个类别索引  # 创建CrossEntropyLoss实例  
loss_fn = nn.CrossEntropyLoss()  # 计算损失  
loss = loss_fn(logits, labels)  # 反向传播  
loss.backward()

2、L1Loss(L1损失/平均绝对误差)

        1)原理

                L1损失,也称为平均绝对误差(MAE),计算的是预测值与真实值之差绝对值平均值

                L1损失对异常值(即远离平均值的点)的敏感度较低,因为它通过绝对值来度量误差,而绝对值函数在零点附近是线性的。

       

        2)用法示例
loss_fn = nn.L1Loss()  
predictions = torch.randn(3, 5, requires_grad=True)  # 预测值  
targets = torch.randn(3, 5)  # 真实值  # 计算损失  
loss = loss_fn(predictions, targets)  # 反向传播  
loss.backward()

3、NLLLoss(负对数似然损失)

        1)原理

                负对数似然损失(NLLLoss)通常与log_softmax一起使用,用于多分类问题。它计算的是目标类别负对数概率

                NLLLoss期望的输入是对数概率(即已经通过log_softmax处理过的输出),然后计算目标类别的负对数概率。

        2)用法示例
# 假设已经计算了logits  
logits = torch.randn(3, 5, requires_grad=True)  # 应用log_softmax获取对数概率(在PyTorch中,通常直接使用CrossEntropyLoss)  
log_probs = torch.log_softmax(logits, dim=1)  # 创建NLLLoss实例  
loss_fn = nn.NLLLoss()  # 真实标签  
labels = torch.tensor([1, 0, 4], dtype=torch.long)  # 计算损失  
loss = loss_fn(log_probs, labels)  # 反向传播  
loss.backward()

                在实际应用中,直接使用CrossEntropyLoss更为常见,因为它内部集成了softmax和NLLLoss的计算。

4、 MSELoss(均方误差损失)

        1)定义

                均方误差损失(MSE)计算的是预测值与真实值之差的平方的平均值

                MSE通过平方误差来放大较大的误差,从而给予模型更大的惩罚。它是回归问题中最常用的损失函数之一。

        2)用法示例
loss_fn = nn.MSELoss()  
predictions = torch.randn(3, 5, requires_grad=True)  # 预测值  
targets = torch.randn(3, 5)  # 真实值  # 计算损失  
loss = loss_fn(predictions, targets)  # 反向传播  
loss.backward()

5.BCELoss(二元交叉熵损失)

        1)定义

                二元交叉熵损失(BCE)用于二分类问题,计算的是预测概率与真实标签(0或1)之间的交叉熵

                BCE通过计算真实标签对应类别的负对数概率来评估模型的性能。它适用于输出概率的模型,但并不要求输入必须经过sigmoid函数(尽管在实践中很常见)。

        2)用法示例
loss_fn = nn.BCELoss()  # 假设预测值已经通过sigmoid函数(虽然不是必需的)  
predictions = torch.sigmoid(torch.randn(3, requires_grad=True))  # 真实标签  
targets = torch.empty(3).random_(2).float()  # 生成0或1的随机值  # 计算损失  
loss = loss_fn(predictions, targets)  # 反向传播  
loss.backward()

二、总结常用损失函数

        1、nn.CrossEntropyLoss:交叉熵损失函数

                主要用于多分类问题。它将模型的输出(logits)与真实标签进行比较,并计算损失。

        2、nn.MSELoss:均方误差损失函数

                用于回归问题。它计算模型输出与真实标签之间的差异的平方,并返回平均值。

        3、nn.L1Loss:平均绝对误差损失函数

                也称为L1损失。类似于MSELoss,但是它计算模型输出与真实标签之间的差异的绝对值,并返回平均值。

        4、nn.BCELoss:二元交叉熵损失函数

                用于二分类问题。它计算二分类问题中的模型输出与真实标签之间的差异,并返回损失。

        5、nn.NLLLoss:负对数似然损失函数

                主要用于多分类问题。它首先应用log_softmax函数(log_softmax(x) = log(softmax(x)))将模型输出转化为对数概率,然后计算模型输出与真实标签之间的差异。


http://www.ppmy.cn/server/119059.html

相关文章

萌宠宜家商城系统

摘 要 随着现在经济的不断发展和信息技术性日益完善和优化,传统式数据信息的管理升级成手机软件存放、梳理和数据信息集中统一处理的管理方式。本萌宠物宜家商城系统软件起源于这个环境中,能够帮助管理者在短期内进行庞大数据信息。使用这个专业软件能够…

Android WebView H5 Hybrid 混和开发

对于故乡,我忽然有了新的理解:人的故乡,并不止于一块特定的土地,而是一种辽阔无比的心情,不受空间和时间的限制;这心情一经唤起,就是你已经回到了故乡。——《记忆与印象》 前言 移动互联网发展…

胤娲科技:解锁AI奥秘——产品经理的智能进化之旅

当AI不再是遥不可及的科幻 想象一下,你走进一家未来感十足的咖啡厅,无需言语,智能咖啡机就能根据你的偏好调制出一杯完美的拿铁; 打开手机,AI助手不仅提醒你今天有雨,还贴心推荐了最适合雨中漫步的音乐列表…

xml重点笔记(尚学堂 3h)

XML:可扩展标记语言 主要内容(了解即可) 1.XML介绍 2.DTD 3.XSD 4.DOM解析 6.SAX解析 学习目标 一. XML介绍 1.简介 XML(Extensible Markup Language) 可扩展标记语言,严格区分大小写 2.XML和HTML XML是用来传输和存储数据的。 XML多用在框架的配置文件…

linux下使用Mail命令发送邮件的配置、快速实现以及sed命令的一些补充:行结合模式匹配取内容及sed命令显示配置文件中的有效内容

一、linux下使用Mail命令发送邮件的配置及快速实现 之前在服务器上增加了一些日志统计shell脚本并且每周进行一次日志分析统计自动在周一早上发到我的邮箱,最近服务器进行了迁移收缩,又得做点重复的事情,首先是让服务器支持邮件发送。 1&am…

MATLAB中的代码覆盖测试:深入指南与实践应用

在软件测试领域,代码覆盖测试是一种重要的技术,用于评估测试用例的完整性和有效性。在MATLAB环境中,代码覆盖测试可以帮助开发者确保他们的代码在各种条件下都能正常工作,并且能够发现可能被忽视的错误。本文将详细介绍如何在MATL…

rust + bevy 实现小游戏 打包成wasm放在浏览器环境运行

游戏界面 代码地址 github WASM运行 rustup target install wasm32-unknown-unknown cargo install wasm-server-runner cargo run --target wasm32-unknown-unknowncargo install wasm-bindgen-cli cargo build --release --target wasm32-unknown-unknown wasm-bindgen --…

基于Springboot+vue的音乐网站

随着信息技术在管理上越来越深入而广泛的应用,管理信息系统的实施在技术上已逐步成熟。本文介绍了音乐网站的开发全过程。通过分析音乐网站管理的不足,创建了一个计算机管理音乐网站的方案。文章介绍了音乐网站的系统分析部分,包括可行性分析…