《深度学习》PyTorch 常用损失函数原理、用法解析

news/2024/9/23 1:09:50/

目录

一、常用损失函数

1、CrossEntropyLoss(交叉熵损失)

        1)原理

        2)流程

        3)用法示例

2、L1Loss(L1损失/平均绝对误差)

        1)原理

        2)用法示例

3、NLLLoss(负对数似然损失)

        1)原理

        2)用法示例

4、 MSELoss(均方误差损失)

        1)定义

        2)用法示例

5. BCELoss(二元交叉熵损失)

        1)定义

        2)用法示例

二、总结常用损失函数

        1、nn.CrossEntropyLoss:交叉熵损失函数

        2、nn.MSELoss:均方误差损失函数

        3、nn.L1Loss:平均绝对误差损失函数

        4、nn.BCELoss:二元交叉熵损失函数

        5、nn.NLLLoss:负对数似然损失函数


一、常用损失函数

1、CrossEntropyLoss(交叉熵损失)

        1)原理

                交叉熵损失是一种常用于分类问题的损失函数,它衡量的是模型输出的概率分布与真实标签分布之间的差异

                在多分类问题中,模型会输出每个类别的预测概率。交叉熵损失通过计算真实标签对应类别的负对数概率评估模型的性能。在实际应用中,nn.CrossEntropyLoss内部会对logits(即未经softmax的原始输出)应用softmax函数,将其转换为概率分布,然后计算交叉熵。

                例如:

                        假设有一个多类别分类任务,共有C个类别。对于每个样本,模型会输出一个包含C个元素的向量,其中每个元素表示该样本属于对应类别的概率。而真实标签是一个C维的向量,其中只有一个元素为1,其余元素均为0,表示样本的真实类别。

        2)流程

                首先,将模型输出的向量通过softmax函数进行归一化,将原始的概率值转换为概率分布。即对模型输出的每个元素进行指数运算,然后对所有元素求和,最后将每个元素除以总和,得到归一化后的概率分布。

                然后,将归一化后的概率分布与真实标签进行比较,计算两者之间的差异。交叉熵损失函数的计算公式为: -sum(y * log(p))  ,其中y是真实标签的概率分布,p是模型输出的归一化后的概率分布。该公式表示真实标签的概率分布与模型输出的归一化后的概率分布之间的交叉熵。

                最后,将每个样本的交叉熵损失值进行求和或平均,得到整个批次的损失值。

       

        3)用法示例
import torch  
import torch.nn as nn  # 假设有一个模型输出的logits和一个真实的标签  
logits = torch.randn(10, 5, requires_grad=True)  # 10个样本,5个类别  
labels = torch.randint(0, 5, (10,))  # 真实标签,每个样本对应一个类别索引  # 创建CrossEntropyLoss实例  
loss_fn = nn.CrossEntropyLoss()  # 计算损失  
loss = loss_fn(logits, labels)  # 反向传播  
loss.backward()

2、L1Loss(L1损失/平均绝对误差)

        1)原理

                L1损失,也称为平均绝对误差(MAE),计算的是预测值与真实值之差绝对值平均值

                L1损失对异常值(即远离平均值的点)的敏感度较低,因为它通过绝对值来度量误差,而绝对值函数在零点附近是线性的。

       

        2)用法示例
loss_fn = nn.L1Loss()  
predictions = torch.randn(3, 5, requires_grad=True)  # 预测值  
targets = torch.randn(3, 5)  # 真实值  # 计算损失  
loss = loss_fn(predictions, targets)  # 反向传播  
loss.backward()

3、NLLLoss(负对数似然损失)

        1)原理

                负对数似然损失(NLLLoss)通常与log_softmax一起使用,用于多分类问题。它计算的是目标类别负对数概率

                NLLLoss期望的输入是对数概率(即已经通过log_softmax处理过的输出),然后计算目标类别的负对数概率。

        2)用法示例
# 假设已经计算了logits  
logits = torch.randn(3, 5, requires_grad=True)  # 应用log_softmax获取对数概率(在PyTorch中,通常直接使用CrossEntropyLoss)  
log_probs = torch.log_softmax(logits, dim=1)  # 创建NLLLoss实例  
loss_fn = nn.NLLLoss()  # 真实标签  
labels = torch.tensor([1, 0, 4], dtype=torch.long)  # 计算损失  
loss = loss_fn(log_probs, labels)  # 反向传播  
loss.backward()

                在实际应用中,直接使用CrossEntropyLoss更为常见,因为它内部集成了softmax和NLLLoss的计算。

4、 MSELoss(均方误差损失)

        1)定义

                均方误差损失(MSE)计算的是预测值与真实值之差的平方的平均值

                MSE通过平方误差来放大较大的误差,从而给予模型更大的惩罚。它是回归问题中最常用的损失函数之一。

        2)用法示例
loss_fn = nn.MSELoss()  
predictions = torch.randn(3, 5, requires_grad=True)  # 预测值  
targets = torch.randn(3, 5)  # 真实值  # 计算损失  
loss = loss_fn(predictions, targets)  # 反向传播  
loss.backward()

5.BCELoss(二元交叉熵损失)

        1)定义

                二元交叉熵损失(BCE)用于二分类问题,计算的是预测概率与真实标签(0或1)之间的交叉熵

                BCE通过计算真实标签对应类别的负对数概率来评估模型的性能。它适用于输出概率的模型,但并不要求输入必须经过sigmoid函数(尽管在实践中很常见)。

        2)用法示例
loss_fn = nn.BCELoss()  # 假设预测值已经通过sigmoid函数(虽然不是必需的)  
predictions = torch.sigmoid(torch.randn(3, requires_grad=True))  # 真实标签  
targets = torch.empty(3).random_(2).float()  # 生成0或1的随机值  # 计算损失  
loss = loss_fn(predictions, targets)  # 反向传播  
loss.backward()

二、总结常用损失函数

        1、nn.CrossEntropyLoss:交叉熵损失函数

                主要用于多分类问题。它将模型的输出(logits)与真实标签进行比较,并计算损失。

        2、nn.MSELoss:均方误差损失函数

                用于回归问题。它计算模型输出与真实标签之间的差异的平方,并返回平均值。

        3、nn.L1Loss:平均绝对误差损失函数

                也称为L1损失。类似于MSELoss,但是它计算模型输出与真实标签之间的差异的绝对值,并返回平均值。

        4、nn.BCELoss:二元交叉熵损失函数

                用于二分类问题。它计算二分类问题中的模型输出与真实标签之间的差异,并返回损失。

        5、nn.NLLLoss:负对数似然损失函数

                主要用于多分类问题。它首先应用log_softmax函数(log_softmax(x) = log(softmax(x)))将模型输出转化为对数概率,然后计算模型输出与真实标签之间的差异。


http://www.ppmy.cn/news/1529090.html

相关文章

【STM32系统】基于STM32设计的DAC输出电压与ADC检测电压系统(简易万用表,检测电压电流)——文末工程资料下载

基于STM32设计的DAC输出电压与ADC检测电压系统(简易万用表,检测电压电流) 演示视频: 基于STM32设计的DAC输出电压与ADC检测电压系统(简易万用表,检测电压电流) 前言:本项目实现对STM32的DAC和ADC的程序设计与硬件电路连接实现STM32内部DAC输出电压,并且ADC可以采集电压…

dedecms(四种webshell姿势)、aspcms webshell漏洞复现

一、aspcms webshell 1、登陆后台&#xff0c;在扩展功能的幻灯片设置模块&#xff0c;点击保存进行抓包查看 2、在slideTextStatus写入asp一句话木马 1%25><%25Eval(Request(chr(65)))%25><%25 密码是a&#xff0c;放行&#xff0c;修改成功 3、使用菜刀工具连…

AWS 消息通知系统 SNS

AWS 消息通知系统 SNS 引言什么是 AWS SNSSNS 的工作原理SNS 的主要应用场景示例&#xff1a;创建 SNS 主题和订阅使用 AWS 管理控制台使用 AWS CLI使用 AWS SDK (Python Boto3) 示例 CloudWatch 如何通过 SNS 发送告警通知 引言 《AWS 监控和管理服务 CloudWatch》有介绍 Clo…

Java项目实战II基于Java+Spring Boot+MySQL的作业管理系统设计与实现(源码+数据库+文档)

目录 一、前言 二、技术介绍 三、系统实现 四、论文参考 五、核心代码 六、源码获取 全栈码农以及毕业设计实战开发&#xff0c;CSDN平台Java领域新星创作者 一、前言 在教育信息化的大潮中&#xff0c;作业管理作为教学过程中的重要环节&#xff0c;其效率与效果直接影…

“吉林一号”宽幅02B系列卫星

离轴四反光学成像系统 1.光学系统参数&#xff1a; 焦距&#xff1a;77.5mm&#xff1b; F/#&#xff1a;7.4&#xff1b; 视场&#xff1a;≥56゜&#xff1b; 光谱范围&#xff1a;400nm&#xff5e;1000nm。 2.说明&#xff1a; 光学系统采用离轴全反射式结构&#xff0c;整…

滑动窗口(6)_找到字符串中所有字母异位词

个人主页&#xff1a;C忠实粉丝 欢迎 点赞&#x1f44d; 收藏✨ 留言✉ 加关注&#x1f493;本文由 C忠实粉丝 原创 滑动窗口(6)_找到字符串中所有字母异位词 收录于专栏【经典算法练习】 本专栏旨在分享学习算法的一点学习笔记&#xff0c;欢迎大家在评论区交流讨论&#x1f4…

极狐GitLab CI/CD 功能合集(超详细教程)

极狐GitLab 是 GitLab 在中国的发行版&#xff0c;专门面向中国程序员和企业提供企业级一体化 DevOps 平台&#xff0c;用来帮助用户实现需求管理、源代码托管、CI/CD、安全合规&#xff0c;而且所有的操作都是在一个平台上进行&#xff0c;省事省心省钱。可以一键安装极狐GitL…

Java启动Tomcat: Can‘t load IA 32-bit .dll on a AMD 64-bit platform报错问题解决

&#x1f3ac; 鸽芷咕&#xff1a;个人主页 &#x1f525; 个人专栏: 《C干货基地》《粉丝福利》 ⛺️生活的理想&#xff0c;就是为了理想的生活! 专栏介绍 在软件开发和日常使用中&#xff0c;BUG是不可避免的。本专栏致力于为广大开发者和技术爱好者提供一个关于BUG解决的经…