CrossEntropy(交叉熵损失函数pytorch)

news/2024/11/14 14:17:00/

介绍

crossentropy损失函数主要用于多分类任务。它计算了模型输出与真实标签之间的交叉熵损失,可以作为模型优化的目标函数。

在多分类任务中,每个样本有多个可能的类别,而模型输出的是每个样本属于每个类别的概率分布。交叉熵损失函数可以度量模型输出的概率分布与真实标签之间的距离,从而指导模型优化。

Pytorch库的用法

class torch.nn.CrossEntropyLoss(weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='mean')

参数介绍

  • weight, 为一维张量,具体的大小为M,M为样本的标签数量,代表赋予的类别的权重
  • ignore_index,int类型数据,用于指定忽略某个类别的索引。默认为 -100,表示不忽略任何类别。
  • reduction:指定损失函数的计算方式。可选项包括:‘none’(不返回每个样本的损失值)、‘mean’(返回每个样本的平均损失值)、‘sum’(返回每个样本的总损失值)。

具体使用例子

import torch
import torch.nn as nn
batch_size = 32
class_num = 3
inputs = torch.rand(batch_size, class_num) # [32, 3]
target = torch.randint(0, 3, size=(batch_size,)) # [32]
softmax = nn.Softmax()
inputs = softmax(inputs)
loss_func = nn.CrossEntropyLoss()
predict = loss_func(inputs, target)
print(predict)
# 需要注意的是需要先定义损失函数/softmax函数,而且在设置size的时候需要额外多加入一个括号

模型输入

  • inputs:模型的输出,形状为 (batch_size, class_num),class_num 表示类别数。可以看作是每个样本被分到每个类别的概率值(这里一般需要用softmax等进行进行转化)。
  • target:真实标签,形状为 (batch_size),其中每个元素的值是样本所属的类别索引。

计算方法

二分类交叉熵损失函数

L = 1 N ∑ i L i = 1 N ∑ i − [ y i ⋅ log ⁡ ( p i ) + ( 1 − y i ) ⋅ log ⁡ ( 1 − p i ) ] L=\frac{1}{N} \sum_i L_i=\frac{1}{N} \sum_i-\left[y_i \cdot \log \left(p_i\right)+\left(1-y_i\right) \cdot \log \left(1-p_i\right)\right] L=N1iLi=N1i[yilog(pi)+(1yi)log(1pi)]

参数介绍

  • N,代表了N个样本
  • L i L_{i} Li,为某个样本的对应损失函数的值
  • y i y_{i} yi为样本的label数值,如果是就为1,不是就为0
  • p i p_{i} pi为模型输出的概率分布(数值),位于0-1之间

多分类交叉熵损失函数

L = 1 N ∑ i L i = − 1 N ∑ i ∑ c = 1 M y i c log ⁡ ( p i c ) L=\frac{1}{N} \sum_i L_i=-\frac{1}{N} \sum_i \sum_{c=1}^M y_{i c} \log \left(p_{i c}\right) L=N1iLi=N1ic=1Myiclog(pic)

参数介绍

  • N,代表了N个样本
  • M,为M个种类或者类别
  • y i c y_{ic} yic,代表的是第i个样本对于第C个种类的label数值
  • p i c p_{ic} pic,代表的是第i个样本对于第C个种类的概率分布/(数值)

优点

在使用反向传播,梯度下降优化的时候,模型取决于学习率(learning rate)和偏导值,而且学习率我们可以手工设置,因此我们从偏导数出发。偏导数越大,证明模型的效果越差,但也会让学习的速率越快,因此使用交叉熵损失函数,在模型效果较差的时候学习速度会较快,更容易收敛。

缺点

注重的任务为分类,更容易学习不同类别之间的信息,较为关心正确预测概率的准确性,容易忽略其他的标签的差异和联系。学习得到的特征较为松散。

参考

损失函数|交叉熵损失函数(知乎)
维基百科交叉熵介绍


http://www.ppmy.cn/news/714788.html

相关文章

惠普暗夜精灵3plus配置ubuntu18.0.4、cuda9.0、cudnn7.0、anaconda(python2.7)、tensorflow-gpu1.8、keras、opencv等

一、ubuntu18.0.4 新的版本内核为22,,安装用u盘启动,分区为/,swap,/boot。桌面版画面更精致,ubuntu图形界面越来越漂亮。自带集成显卡。 二、cuda9.0 先安装显卡驱动,…

上半年结束,下半年继续冲!

前言: 这周直播也把雷神写的Ffmpeg推流器讲解完了,而一同时,一转眼间,2023年已经过半,正式进入了下半年: 因为上半年已经开始在做解析Ffmpeg 最新版本的源码,所以下半年,我会继续坚持讲解Ffmpeg…

Flutter 引入包import的各种含义,及常用命名规范

一、import含义 import dart:xxx; 引入Dart标准库 import xxx/xxx.dart;引入相对路径的Dart文件 import package:xxx/xxx.dart;引入Pub仓库pub.dev(或者pub.flutter-io.cn)中的第三方库 import package:project/xxx/xxx.dart;引入自定义的dart文件 impo…

基于html的漫画静态网站设计

目 录 摘 要 1 第一章 引言 2 1.1研究背景 2 1.2研究意义 2 第二章 漫画网站设计概述 2 2.1选题的目的和意义 3 2.2课题研究的主要简介 4 第三章 具体实现与分析 4 3.1静态设计 4 3.2站点的建设与收集素材 4 3.2.1创建本地站点的具体操作步骤如下: 4 3.2.2收集素材&…

python接口自动化(十二)--https请求(SSL)(详解)

简介 本来最新的requests库V2.13.0是支持https请求的,但是一般写脚本时候,我们会用抓包工具fiddler,这时候会 报:requests.exceptions.SSLError: [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed (_ssl.c:590) 小编…

TextMining day1 电力设备运维过程中的短文本挖掘框架

电力设备运维过程中的短文本挖掘框架 III. 短文本挖掘框架的具体设计A. 预处理模块的具体设计B. 数据清洗模块的具体设计C. 表示模块的具体设计D. 数据分析模块的具体设计 IV. 案例研究A. 基于文本分类的缺陷程度判断B. 基于文本检索的缺陷处理决策 V. 结论 预处理 首先&#x…

Agilent/HP 8753D网络分析仪 30kHz-6GHz

性能特点: *频率范围:30kHz~3或6GHz *带有固态转换的集成化S参数测试装置 *达110dB的动态范围 *快的测量速度和数据传递速率 *大屏幕LCD显示器加上供外部监视器用的VGA输出 *同时显示所有4个S参数 *将仪器状态和数据存储/调用到内置软盘驱动…

一文了解Docker之网络模型

目录 1.Docker网络 1.1 Docker网络模型概述 1.2 Docker网络驱动程序 1.2.1 host模式 1.2.2 bridge模式 1.2.3 container模式 1.2.4 none模式 1.3 Docker网络命令示例 1.3.1 创建一个自定义网络 1.3.2 列出所有网络 1.3.3 连接容器到网络 1.3.4 断开容器与网络的连接…