Pytorch的F.cross_entropy交叉熵函数

news/2025/2/24 4:59:47/

参考笔记:pytorch的F.cross_entropy交叉熵函数和标签平滑函数_怎么给crossentropyloss添加标签平滑-CSDN博客

先来讲下基本的交叉熵cross_entropy,官网如下:torch.nn.functional.cross_entropy — PyTorch 1.12 documentation

python">torch.nn.functional.cross_entropy(input, target, weight=None, size_average=None, ignore_index=- 100, reduce=None, reduction='mean', label_smoothing=0.0)loss=F.cross_entropy(input, target)

从官网所给的资料及案例,根据传入参数的数据形式,交叉熵函数的计算方式有两种

目录

1.第一种

2.第二种


1.第一种

N:样本个数,C:类别数

input:{Tensor:(N,C)},存放每个样本对于某种类别的预测值

target:{Tensor:(N,)},存放每个样本的真实类别

python">input = torch.randn(3, 5, requires_grad=True) #input[i][j]表示第i个样本第j个类别的预测值
target = torch.randint(5, (3,), dtype=torch.int64) #target[i]表示第i个样本的类别
loss = F.cross_entropy(input, target)

交叉熵损失的计算公式如下:

交叉熵损失计算公式

  • 第1步:对 input 中每个样本对于各类别的预测值先进行 Softmax ,将预测值转化为每个样本对于某个类别的概率,得到 input_soft Softmax 的计算公式如下所示:

Softmax计算公式

  • 第2步: input_soft 进行 log 运算,记为 input_soft_log
  • 第3步:对 input_soft_log 与真实值 target 进行处理

(1)先将 target 进行 one-hot 编码,即对于每个样本,生成一个长度为 C one-hot 标签向量,其中真实类别值为 1 ,其他类别值为 0 ,如下所示:

python">target = [[0 0 1 0 0],[0 0 0 1 0],[0 0 0 0 1]]

(2)计算第一个样本的损失

P = [0 0 1 0 0]

Q = [-1.9134, -1.2139, -0.1602,  1.1279,  0.4113]

logQ = [-3.6874, -2.9879, -1.9342, -0.6461, -1.3627]

由交叉熵公式可得,第一个样本的损失值计算方法如下:

第一个样本的损失值计算 

同理可得,第二、第三个样本的损失值分别为 - logQ(3)、-logQ(4)

(3)最后计算每个样本损失值的平均值:

Loss =(-(logQ(2) +logQ(3) +logQ(4)))/3


调用 Pytorch 实现的交叉熵函数计算 loss 值:

 自己实现交叉熵损失函数的流程计算 loss 值:

2.第二种

N:样本个数,C:类别数

input:{Tensor:(N,C)},存放每个样本对于某种类别的预测值

target:{Tensor:(N,C)},存放每个样本的对于某种类别的真实概率

交叉熵损失的计算公式如下:

交叉熵损失函数计算公式

计算流程与第一种大致相同,只是不再需要对 taget one-hot 编码,将交叉熵损失函数计算公式中的 P 值用真实概率值代替即可


这里以计算第一个的样本的损失值为例:

P = [0.1062, 0.3173, 0.1361, 0.2172, 0.2232]

Q = [-1.2470,  0.5475, -0.4514, -1.3397, -0.8915]

logQ = [-2.4485, -0.6540, -1.6529, -2.5411, -2.0930]

由交叉熵公式可得,第一个样本的损失值计算方法如下:

第一个样本的损失值计算 

 同理可得,第二、第三个样本的损失值分别为1.5416, 2.0304

最后计算三个样本的平均损失值:

Loss=(1.7715+1.5416+2.0304) \div 3 = 1.7611


自己实现交叉熵损失函数的流程计算 loss 值:


http://www.ppmy.cn/news/1574551.html

相关文章

基于 PyQt5 实现分组列表滚动吸顶效果

基于 PyQt5 实现分组列表滚动吸顶效果 在很多应用场景中,例如 QQ 好友列表,我们都需要展示大量分组数据,同时希望在滚动时分组标题始终固定显示在顶部,提升用户体验。本文将详细介绍如何利用 PyQt5 实现类似效果——在滚动区域中…

全局错误处理如何与Vue Router集成?

将全局错误处理与 Vue Router 集成可以确保在应用中处理错误的一致性,并在用户遇到未授权访问或其他错误时提供适当的反馈。以下是如何将全局错误处理与 Vue Router 集成的步骤和示例。 1. 设置全局错误处理 首先,您可以在 main.js 文件中设置全局错误…

【Bert】自然语言(Language Model)入门之---Bert

every blog every motto: Although the world is full of suffering, it is full also of the overcoming of it 0. 前言 对bert进行梳理 论文: BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding 时间:…

[Android]App生命周期

类似iOS的applicationWillEnterForeground:等方法 以下是使用 Application.ActivityLifecycleCallbacks 接口来监听应用启动和进入前台的示例代码。 创建一个自定义的 ActivityLifecycleCallbacks 首先,创建一个实现 Application.ActivityLifecycleCallbacks 的类…

前端面试之Flex布局:核心机制与高频考点全解析

目录 引言:弹性布局的降维打击 一、Flex布局的本质认知 1. 两大核心维度 2. 容器与项目的权力边界 二、容器属性深度剖析 1. 主轴控制三剑客 2. 交叉轴对齐黑科技 三、项目属性关键要点 1. flex复合属性解密 2. 项目排序魔法 四、六大高频面试场景 1. 经…

挑选出行数足够的excel文件

** 遍历文件夹下的所有excel文件,并将数据量超过指定标准的文件拷贝到指定文件夹中 import os.path import shutil import pandas as pddef copy_excel_files(source_folder, target_folder, row_threshold):if not os.path.exists(target_folder):os.makedirs(ta…

阅读《Vue.js设计与实现》 -- 01

菜鸟最近闲暇(大的项目没开始,别的项目基本没事了),不知道干啥(刷掘金刷多了,感觉文章都写得差不多,不知道学什么,原因如下:沸点),所以开始看《Vu…

3.Docker常用命令

1.Docker启动类命令 1.启动Docker systemctl start docker 2.停止Docker systemctl stop docker 3.重启Docker systemctl restart docker 4.查看Docker状态 systemctl status docker 5.设置开机自启(执行此命令后每次Linux重启后将自启动Docker) systemctl enable do…