[实践应用] 深度学习之损失函数

embedded/2024/9/23 9:28:00/

文章总览:YuanDaiMa2048博客文章总览


深度学习损失函数

    • 1. 回归任务
      • 1.1 均方误差 (MSE)
      • 1.2 平均绝对误差 (MAE)
    • 2. 二分类任务
      • 2.1 二元交叉熵 (Binary Cross-Entropy)
    • 3. 多分类任务
      • 3.1 类别交叉熵 (Categorical Cross-Entropy)
    • 4. 序列生成任务(例如,机器翻译)
        • 4.1 序列交叉熵 (Sequence Cross-Entropy)
    • 5. 回归任务的正则化
        • 5.1 L2 正则化(权重衰减)
    • 其他介绍

机器学习深度学习中,不同的任务使用不同的损失函数来衡量模型的性能。

1. 回归任务

任务: 预测一个连续的数值。

1.1 均方误差 (MSE)

原理: MSE 衡量预测值与实际值之间的平方差的平均值,适用于回归任务。它对异常值敏感。

公式:
MSE = 1 n ∑ i = 1 n ( 预测值 i − 实际值 i ) 2 \text{MSE} = \frac{1}{n} \sum_{i=1}^{n}(\text{预测值}_i - \text{实际值}_i)^2 MSE=n1i=1n(预测值i实际值i)2

PyTorch 代码:

python">import torch
import torch.nn as nn# 定义均方误差损失函数
mse_loss = nn.MSELoss()

1.2 平均绝对误差 (MAE)

原理: MAE 衡量预测值与实际值之间的绝对差的平均值,对异常值不太敏感。

公式:
MAE = 1 n ∑ i = 1 n ∣ 预测值 i − 实际值 i ∣ \text{MAE} = \frac{1}{n} \sum_{i=1}^{n}|\text{预测值}_i - \text{实际值}_i| MAE=n1i=1n预测值i实际值i

PyTorch 代码:

python">import torch
import torch.nn as nn# 定义平均绝对误差损失函数
mae_loss = nn.L1Loss()

2. 二分类任务

任务: 预测样本属于两个类别中的一个(例如,垃圾邮件分类)。

2.1 二元交叉熵 (Binary Cross-Entropy)

原理: 计算预测概率与实际标签之间的交叉熵,用于二分类任务。

公式:
BCE = − 1 n ∑ i = 1 n [ y i log ⁡ ( p i ) + ( 1 − y i ) log ⁡ ( 1 − p i ) ] \text{BCE} = -\frac{1}{n} \sum_{i=1}^{n} [y_i \log(p_i) + (1 - y_i) \log(1 - p_i)] BCE=n1i=1n[yilog(pi)+(1yi)log(1pi)]

PyTorch 代码:

python">import torch
import torch.nn as nn# 定义二元交叉熵损失函数
bce_loss = nn.BCEWithLogitsLoss()  # 结合了 Sigmoid 激活和 BCE 损失

3. 多分类任务

任务: 预测样本属于多个类别中的一个(例如,手写数字分类)。

3.1 类别交叉熵 (Categorical Cross-Entropy)

原理: 计算预测的概率分布与实际类别之间的交叉熵,用于多分类任务。

公式:
CCE = − 1 n ∑ i = 1 n ∑ k = 1 K y i , k log ⁡ ( p i , k ) \text{CCE} = -\frac{1}{n} \sum_{i=1}^{n} \sum_{k=1}^{K} y_{i,k} \log(p_{i,k}) CCE=n1i=1nk=1Kyi,klog(pi,k)

其中 K K K 是类别数, y i , k y_{i,k} yi,k 是实际类别的 one-hot 编码, p i , k p_{i,k} pi,k 是预测的概率。

PyTorch 代码:

python">import torch
import torch.nn as nn# 定义类别交叉熵损失函数
cross_entropy_loss = nn.CrossEntropyLoss()  # 直接对 logits 应用 Softmax 和计算交叉熵

4. 序列生成任务(例如,机器翻译)

任务: 预测序列中每个位置的类别(例如,翻译每个单词)。

4.1 序列交叉熵 (Sequence Cross-Entropy)

原理: 与多分类交叉熵类似,但应用于序列数据,计算预测序列与实际序列之间的交叉熵。

公式:
Sequence CCE = − 1 n ∑ i = 1 n ∑ t = 1 T ∑ k = 1 K y i , t , k log ⁡ ( p i , t , k ) \text{Sequence CCE} = -\frac{1}{n} \sum_{i=1}^{n} \sum_{t=1}^{T} \sum_{k=1}^{K} y_{i,t,k} \log(p_{i,t,k}) Sequence CCE=n1i=1nt=1Tk=1Kyi,t,klog(pi,t,k)

PyTorch 代码:

python">import torch
import torch.nn as nn# 对于序列生成任务,通常使用 CrossEntropyLoss 处理每个时间步的预测
sequence_cross_entropy_loss = nn.CrossEntropyLoss()

5. 回归任务的正则化

任务: 通过将正则化项添加到损失函数来防止过拟合。

5.1 L2 正则化(权重衰减)

原理: 在损失函数中添加权重的平方和,鼓励较小的权重值。

公式:
Regularized Loss = 原始损失 + λ ∑ j = 1 m W j 2 \text{Regularized Loss} = \text{原始损失} + \lambda \sum_{j=1}^{m} W_j^2 Regularized Loss=原始损失+λj=1mWj2

PyTorch 代码:

python">import torch.optim as optim# 定义模型
model = nn.Linear(10, 1)# 定义优化器,并添加 L2 正则化(weight_decay)
optimizer = optim.SGD(model.parameters(), lr=0.01, weight_decay=0.01)

其他介绍


http://www.ppmy.cn/embedded/111904.html

相关文章

基于python+django+vue的个性化餐饮管理系统

作者:计算机学姐 开发技术:SpringBoot、SSM、Vue、MySQL、JSP、ElementUI、Python、小程序等,“文末源码”。 专栏推荐:前后端分离项目源码、SpringBoot项目源码、SSM项目源码 系统展示 【2025最新】基于pythondjangovueMySQL的视…

ubuntu64位系统无法运行32位程序的解决办法

在 64 位的 Ubuntu 系统上运行 32 位程序时,如果出现问题,可能是由于缺少 32 位库支持。以下步骤可以帮助你解决这一问题: 1. 启用 32 位架构 首先,确保系统支持 32 位架构。你可以通过以下命令添加 32 位架构支持: …

Nginx:高性能的Web服务器与反向代理

在当今的互联网世界中,Web服务器的选择对于网站的性能、稳定性和安全性至关重要。Nginx(发音为“engine X”)凭借其卓越的性能、丰富的功能集和灵活的配置选项,成为了众多网站和应用程序的首选Web服务器和反向代理。本文将深入探讨…

力扣279-完全平方数(Java详细题解)

题目链接:279. 完全平方数 - 力扣(LeetCode) 前情提要: 因为本人最近都来刷dp类的题目所以该题就默认用dp方法来做。 最近刚学完背包,所以现在的题解都是以背包问题为基础再来写的。 如果大家不懂背包问题的话&…

【25.2】C++智能交友系统

Girl类代码补充 对一些成员函数定义的修改 .h文件 #pragma once #include <string> #include <sstream> using namespace std;class Boy;class Girl { public:Girl();Girl(int age, string name, int style);~Girl();int getAge() const;string getName() const…

搜维尔科技:ART光学空间定位虚拟交互工业级光学跟踪系统

ART光学空间定位虚拟交互工业级光学跟踪系统 搜维尔科技&#xff1a;ART光学空间定位虚拟交互工业级光学跟踪系统

Ansible自动化部署kubernetes集群

机器环境介绍 1.1. 机器信息介绍 IP hostname application CPU Memory 192.168.204.129 k8s-master01 etcd&#xff0c;kube-apiserver&#xff0c;kube-controller-manager&#xff0c;kube-scheduler,kubelet,kube-proxy,containerd 2C 4G 192.168.204.130 k8s-w…

SpringBoot整合WebSocket实现消息推送或聊天功能示例

最近在做一个功能&#xff0c;就是需要实时给用户推送消息&#xff0c;所以就需要用到 websocket springboot 接入 websocket 非常简单&#xff0c;只需要下面几个配置即可 pom 文件 <!-- spring-boot-web启动器 --><dependency><groupId>org.springframewo…