Pytorch的F.cross_entropy交叉熵函数

devtools/2025/2/24 23:22:56/

参考笔记:pytorch的F.cross_entropy交叉熵函数和标签平滑函数_怎么给crossentropyloss添加标签平滑-CSDN博客

先来讲下基本的交叉熵cross_entropy,官网如下:torch.nn.functional.cross_entropy — PyTorch 1.12 documentation

python">torch.nn.functional.cross_entropy(input, target, weight=None, size_average=None, ignore_index=- 100, reduce=None, reduction='mean', label_smoothing=0.0)loss=F.cross_entropy(input, target)

从官网所给的资料及案例,根据传入参数的数据形式,交叉熵函数的计算方式有两种

目录

1.第一种

2.第二种


1.第一种

N:样本个数,C:类别数

input:{Tensor:(N,C)},存放每个样本对于某种类别的预测值

target:{Tensor:(N,)},存放每个样本的真实类别

python">input = torch.randn(3, 5, requires_grad=True) #input[i][j]表示第i个样本第j个类别的预测值
target = torch.randint(5, (3,), dtype=torch.int64) #target[i]表示第i个样本的类别
loss = F.cross_entropy(input, target)

交叉熵损失的计算公式如下:

交叉熵损失计算公式

  • 第1步:对 input 中每个样本对于各类别的预测值先进行 Softmax ,将预测值转化为每个样本对于某个类别的概率,得到 input_soft Softmax 的计算公式如下所示:

Softmax计算公式

  • 第2步: input_soft 进行 log 运算,记为 input_soft_log
  • 第3步:对 input_soft_log 与真实值 target 进行处理

(1)先将 target 进行 one-hot 编码,即对于每个样本,生成一个长度为 C one-hot 标签向量,其中真实类别值为 1 ,其他类别值为 0 ,如下所示:

python">target = [[0 0 1 0 0],[0 0 0 1 0],[0 0 0 0 1]]

(2)计算第一个样本的损失

P = [0 0 1 0 0]

Q = [-1.9134, -1.2139, -0.1602,  1.1279,  0.4113]

logQ = [-3.6874, -2.9879, -1.9342, -0.6461, -1.3627]

由交叉熵公式可得,第一个样本的损失值计算方法如下:

第一个样本的损失值计算 

同理可得,第二、第三个样本的损失值分别为 - logQ(3)、-logQ(4)

(3)最后计算每个样本损失值的平均值:

Loss =(-(logQ(2) +logQ(3) +logQ(4)))/3


调用 Pytorch 实现的交叉熵函数计算 loss 值:

 自己实现交叉熵损失函数的流程计算 loss 值:

2.第二种

N:样本个数,C:类别数

input:{Tensor:(N,C)},存放每个样本对于某种类别的预测值

target:{Tensor:(N,C)},存放每个样本的对于某种类别的真实概率

交叉熵损失的计算公式如下:

交叉熵损失函数计算公式

计算流程与第一种大致相同,只是不再需要对 taget one-hot 编码,将交叉熵损失函数计算公式中的 P 值用真实概率值代替即可


这里以计算第一个的样本的损失值为例:

P = [0.1062, 0.3173, 0.1361, 0.2172, 0.2232]

Q = [-1.2470,  0.5475, -0.4514, -1.3397, -0.8915]

logQ = [-2.4485, -0.6540, -1.6529, -2.5411, -2.0930]

由交叉熵公式可得,第一个样本的损失值计算方法如下:

第一个样本的损失值计算 

 同理可得,第二、第三个样本的损失值分别为1.5416, 2.0304

最后计算三个样本的平均损失值:

Loss=(1.7715+1.5416+2.0304) \div 3 = 1.7611


自己实现交叉熵损失函数的流程计算 loss 值:


http://www.ppmy.cn/devtools/161449.html

相关文章

LangChain-基础(prompts、序列化、流式输出、自定义输出)

LangChain-基础 我们现在使用的大模型训练数据都是基于历史数据训练出来的,它们都无法处理一些实时性的问题或者一些在训练时为训练到的一些问题,解决这个问题有2种解决方案 基于现有的大模型上进行微调,使得它能适应这些问题(本…

Linux(ubuntu) GPU CUDA 构建Docker镜像

一、创建Dockerfile FROM ubuntu:20.04#非交互式,以快速运行自动化任务或脚本,无需图形界面 ENV DEBIAN_FRONTENDnoninteractive# 安装基础工具 RUN apt-get update && apt-get install -y \curl \wget \git \build-essential \software-proper…

Nginx:服务架构中不可或缺的基础组件

基本所有的服务架构中,Nginx 都是不可或缺的基础组件: HTTP 负载均衡,将请求转发到后端 API 服务器。健康检查,将后端无法服务的节点移除。配置 HTTPS ,增强安全性。静态服务器,随着前后端分离&#xff0c…

《跟李沐学 AI》AlexNet论文逐段精读学习心得 | PyTorch 深度学习实战

前一篇文章,使用 AlexNet 实现图片分类 | PyTorch 深度学习实战 本系列文章 GitHub Repo: https://github.com/hailiang-wang/pytorch-get-started 本篇文章内容来自于学习 9年后重读深度学习奠基作之一:AlexNet【下】【论文精读】】的心得。 《跟李沐…

Unity 位图字体

下载Bitmap Font Generator BMFont - AngelCode.com 解压后不用安装直接双击使用 提前设置 1、设置Bit depth为32 Options->Export options 2、清空所选字符 因为我们将在后边导入需要的字符。 Edit->Select all chars 先选择所有字符 Edit->Clear all chars i…

Web项目测试专题(七)安全性测试

概述: 安全性测试旨在确保Web应用在设计和实现过程中能够抵御各种安全威胁,保护用户数据和系统资源。 步骤: 身份验证和授权:测试用户登录、权限管理和会话管理机制,确保只有授权用户能够访问特定资源。 数据加密…

Java 阻塞队列:让并发更“懂事”

阻塞队列的常见方法 阻塞队列的一些常用方法就是让你在多线程操作时轻松控制数据流。让我们看几个经典的方法: put(E e) 这个方法会将元素 e 放入队列中。如果队列已满,它会阻塞当前线程直到队列有空间可用。 大家好,今天我们来聊一聊 Jav…

【落羽的落羽 数据结构篇】栈和队列

文章目录 一、栈1. 概念2. 栈操作2.1 定义栈结构2.2 栈的初始化2.3 入栈2.4 出栈2.5 取栈顶元素 3. 栈的使用实例 二、队列1. 概念2. 队列操作2.1 定义队列结构2.2 入队列2.3 出队列2.4 销毁队列 三、用队列实现栈四、用栈实现队列 一、栈 1. 概念 栈(stack&#…