深度学习:神经网络中的损失函数的使用

server/2024/11/24 20:03:54/

深度学习神经网络中的损失函数的使用

损失函数是监督学习中的关键组成部分,用于衡量模型预测值与真实值之间的差异。优化算法(如梯度下降)通过最小化损失函数来调整模型参数,以提高模型的预测精度。以下是几种常用的损失函数及其在PyTorch中的实现和应用的详细解释:

1. L1 损失(绝对误差损失)

L1 损失是一个基于预测值和真实值之间绝对差值的损失函数,常用于回归问题。它有助于提高模型的鲁棒性,尤其是在异常值存在的情况下。

数学表达式

[ L ( y , y ^ ) = ∑ i = 1 n ∣ y i − y ^ i ∣ L(y, \hat{y}) = \sum_{i=1}^n |y_i - \hat{y}_i| L(y,y^)=i=1nyiy^i ]
其中 ( y i y_i yi) 是真实值,( y ^ i \hat{y}_i y^i) 是预测值。

PyTorch 实现
import torch
import torch.nn as nnloss_fn = nn.L1Loss()
y_true = torch.tensor([2, 3, 4, 5], dtype=torch.float)
y_pred = torch.tensor([1.5, 3.5, 3.8, 5.2], dtype=torch.float)
loss = loss_fn(y_pred, y_true)
示例

计算 L1 损失:
[ $L = |2 - 1.5| + |3 - 3.5| + |4 - 3.8| + |5 - 5.2| = 0.5 + 0.5 + 0.2 + 0.2 = 1.4 $]

2. MSE 损失(均方误差损失)

均方误差损失是回归问题中最常用的损失函数之一,计算真实值与预测值之间差值的平方和的均值。它放大了较大误差的影响,使模型更加注重减少大的预测误差。

数学表达式

[ $L(y, \hat{y}) = \frac{1}{n} \sum_{i=1}^n (y_i - \hat{y}_i)^2 KaTeX parse error: Can't use function '\]' in math mode at position 1: \̲]̲ 其中 \(y_iKaTeX parse error: Can't use function '\)' in math mode at position 1: \̲)̲ 是真实值,\(\hat{y}_i$) 是预测值。

PyTorch 实现
loss_fn = nn.MSELoss()
loss = loss_fn(y_pred, y_true)
示例

计算 MSE:
[ L = 1 4 ( ( 2 − 1.5 ) 2 + ( 3 − 3.5 ) 2 + ( 4 − 3.8 ) 2 + ( 5 − 5.2 ) 2 ) = 1 4 ( 0.25 + 0.25 + 0.04 + 0.04 ) = 0.145 L = \frac{1}{4}((2 - 1.5)^2 + (3 - 3.5)^2 + (4 - 3.8)^2 + (5 - 5.2)^2) = \frac{1}{4}(0.25 + 0.25 + 0.04 + 0.04) = 0.145 L=41((21.5)2+(33.5)2+(43.8)2+(55.2)2)=41(0.25+0.25+0.04+0.04)=0.145 ]

3. 交叉熵损失(Cross-Entropy Loss)

交叉熵损失是分类问题中最常用的损失函数之一,特别适用于多类分类问题。它衡量的是预测概率分布与真实分布之间的差异。

数学表达式

[ L = − ∑ c = 1 M y c log ⁡ ( p c ) L = -\sum_{c=1}^M y_c \log(p_c) L=c=1Myclog(pc) ]
其中 ( y c y_c yc) 是如果样本属于类别 ( c c c),则为1,否则为0;( p c p_c pc) 是预测样本属于类别 ( c c c) 的概率。

PyTorch 实现
loss_fn = nn.CrossEntropyLoss()
# 注意:CrossEntropyLoss的输入不应用one-hot编码,且预测值不通过softmax
y_true = torch.tensor([1])  # 类别索引为1
y_pred = torch.tensor([[0.1, 0.6, 0.3]])  # logits
loss = loss_fn(y_pred, y_true)
示例

计算交叉熵损失:
[ L = − ( 0 ⋅ log ⁡ ( 0.1 ) + 1 ⋅ log ⁡ ( 0.6 ) + 0 ⋅ log ⁡ ( 0.3 ) ) = − log ⁡ ( 0.6 ) ≈ 0.51 L = -(0 \cdot \log(0.1) + 1 \cdot \log(0.6) + 0 \cdot \log(0.3)) = -\log(0.6) \approx 0.51 L=(0log(0.1)+1log(0.6)+0log(0.3))=log(0.6)0.51 ]

总结

损失函数是衡量模型性能的重要工具,通过最小化损失,我们可以使模型在特定任务上表现得更好。选择合适的损失函数对于模型的最终性能至关重要,应根据具体任务和数据的性质来选择。在PyTorch中,使用这些损失函数可以直接通过简单的API调用实现,方便模型的训练和优化。


http://www.ppmy.cn/server/144628.html

相关文章

微深节能 平板小车运动监测与控制系统 格雷母线

微深节能的平板小车运动监测与控制系统中的格雷母线,是一种高精度、非接触式的位移测量系统,在平板小车的运动监测与控制中发挥着核心作用。 一、系统组成 该系统主要由以下关键部件组成: 地面电气柜:包含地址jie码器等重要组件&a…

Vue3与Vue2 对比

作者:东方小月 链接:https://juejin.cn/post/7111129583713255461 来源:稀土掘金 采用选项式,组合式,setup语法糖三种方式对比 首先实现一个同样的逻辑(点击切换页面数据)看一下它们直接的区别 更多前端,python知识请移…

手写一个深拷贝工具

背景 在面向对象编程中,对象之间的复制是一个常见的需求。对象的复制通常分为浅拷贝(Shallow Copy)和深拷贝(Deep Copy)两种方式。浅拷贝只复制对象的基本数据类型和引用类型的数据地址,而深拷贝则会递归地…

React Native的`react-native-reanimated`库中的`useAnimatedStyle`钩子来创建一个动画样式

React Native的react-native-reanimated库中的useAnimatedStyle钩子来创建一个动画样式,用于一个滑动视图的每个项目(SliderItem)。useAnimatedStyle钩子允许你根据动画值(在这个例子中是scrollX)来动态地设置组件的样…

react和vue图片懒加载及实现原理

一、实现原理 核心思想: 只有当图片出现在视口中时,才加载图片。利用占位图或占位背景优化用户体验。 实现技术: 监听滚动事件:监听页面滚动,通过计算图片与视口的位置关系,判断是否需要加载图片。Intersec…

Scala学习记录,Array

数组:物理空间上连续的(一个挨一个)优势:根据下标能快速找到元素。 列表:物理空间上不连续(不是一个元素挨着一个元素的)优势:插入元素,删除较快。 Array定义&#xff…

大数据面试题每日练习--HDFS是如何工作的?

HDFS(Hadoop Distributed File System)是一个分布式文件系统,设计用于存储非常大的文件。它的主要工作原理如下: NameNode:管理文件系统的命名空间,维护文件目录树和文件元数据信息。NameNode记录每个文件…

一学就废|Python基础碎片,列表(List)

列表(数组)是一种常见的数据结构,通常,列表的共性操作包括获取、设置、搜索、过滤和排序。以下是对列表的一些常用的操作方法。 基本操作 我们可以在 Python 中操作列表的方法有很多。在我们开始学习这些通用操作之前,以下片段显示了列表最常…