pytorch torch.nn.functional.one_hot函数介绍

news/2024/9/17 7:46:41/ 标签: pytorch, 深度学习, 机器学习

torch.nn.functional.one_hot 是 PyTorch 中用于生成独热编码(one-hot encoding)张量的函数。独热编码是一种常用的编码方式,特别适用于分类任务或对离散的类别标签进行处理。该函数将整数张量的每个元素转换为一个独热向量。

函数签名

torch.nn.functional.one_hot(tensor, num_classes=-1)

参数

  1. tensor:

    • 输入的整数张量。该张量的每个元素都表示一个类别索引。
    • tensor 的数据类型必须是整数类型(如 torch.LongTensor 或 torch.IntTensor)。
  2. num_classes:

    • 输出独热编码向量的长度,即类别的总数。如果设置为默认值 -1,则 num_classes 会自动设置为输入张量中最大值加1,即 max(tensor) + 1
    • 如果指定 num_classes,生成的每个独热向量的长度就是 num_classes,即使某些类别索引可能小于该值。

输出

  • 输出是一个新张量,其中输入张量的每个整数都被转换为一个独热编码向量。
  • 输出张量的形状为:(*input_shape, num_classes),即在输入张量的最后增加一个维度,代表类别的独热编码。

独热编码示例

独热编码是指在一个向量中,只有一个位置是1,其余位置都是0。例如,如果有三个类别,类别0可以表示为 [1, 0, 0],类别1 表示为 [0, 1, 0],类别2 表示为 [0, 0, 1]

示例

示例 1:简单独热编码
import torch
import torch.nn.functional as F# 假设有类别索引 [0, 1, 2]
labels = torch.tensor([0, 1, 2])
one_hot = F.one_hot(labels, num_classes=3)print(one_hot)

输出:

tensor([[1, 0, 0],[0, 1, 0],[0, 0, 1]])

在这里,类别索引 [0, 1, 2] 分别被编码为独热向量 [1, 0, 0][0, 1, 0] 和 [0, 0, 1]

示例 2:自定义类别数量
# 输入类别索引为 [0, 1, 4]
labels = torch.tensor([0, 1, 4])
one_hot = F.one_hot(labels, num_classes=5)print(one_hot)

输出:

tensor([[1, 0, 0, 0, 0],[0, 1, 0, 0, 0],[0, 0, 0, 0, 1]])

即使 labels 中最大值是 4,指定了 num_classes=5,独热向量的长度为 5。

示例 3:多维输入
# 输入为二维张量
labels = torch.tensor([[0, 1], [2, 3]])
one_hot = F.one_hot(labels, num_classes=4)print(one_hot)

输出:

tensor([[[1, 0, 0, 0],[0, 1, 0, 0]],[[0, 0, 1, 0],[0, 0, 0, 1]]])

输出张量的形状为 (2, 2, 4),即在输入形状 (2, 2) 的基础上,在最后增加了一个维度来表示类别的独热编码。

应用场景

  1. 分类任务: 在神经网络的分类任务中,通常需要将类别标签转换为独热编码。例如在多分类问题中,将标签转换为独热编码后,可以与交叉熵损失函数配合使用。

  2. 序列数据处理: 在自然语言处理任务中,可以使用独热编码将词汇表中的每个单词转换为独热向量,表示该单词在词汇表中的位置。

  3. 距离计算: 在某些算法中,使用独热编码表示类别或索引可以帮助计算不同类别或位置之间的距离。

总结

torch.nn.functional.one_hot 是一个简单但强大的工具,用于将整数标签或类别索引转换为独热编码。它通常用于分类问题的标签预处理,特别是在多类别分类任务中非常有用。


http://www.ppmy.cn/news/1522609.html

相关文章

notepad++软件介绍(含安装包)

Notepad 是一款开源的文本编辑器,主要用于编程和代码编辑。它是一个功能强大的替代品,常常被用来替代 Windows 系统自带的记事本。 Notepad win系统免费下载地址 以下是 Notepad 的一些主要特点和功能: 多语言支持:Notepad 支持多…

Kafka【八】如何保证消息发送的可靠性、重复性、有序性

【1】消息发送的可靠性保证 对于生产者发送的数据,我们有的时候是不关心数据是否已经发送成功的,我们只要发送就可以了。在这种场景中,消息可能会因为某些故障或问题导致丢失,我们将这种情况称之为消息不可靠。虽然消息数据可能会…

proxy代理解决vue中跨域问题

vue.config.js module.exports {...// webpack-dev-server 相关配置devServer: {host: 0.0.0.0,port: port,open: true,proxy: {/api: {target: https://vfadmin.insistence.tech/prod-api,changeOrigin: true,pathRewrite: {//[^ process.env.VUE_APP_BASE_API]: ^/api: / …

【 html+css 绚丽Loading 】000044 两仪穿行轮

前言:哈喽,大家好,今天给大家分享htmlcss 绚丽Loading!并提供具体代码帮助大家深入理解,彻底掌握!创作不易,如果能帮助到大家或者给大家一些灵感和启发,欢迎收藏关注哦 &#x1f495…

【sql】评估数据迁移复杂度调查汇报240904

难度判断标准: - 高难度:使用多个表(TBL)或有多个join操作的工具 - 低难度:表数量少且没有join操作的简单工具 - 中等难度:介于高低之间,有少量join操作的工具 5. 最后说明不需要仔细…

25届计算机毕业设计:3步打造北部湾助农平台,Java SpringBoot实践

✍✍计算机编程指导师 ⭐⭐个人介绍:自己非常喜欢研究技术问题!专业做Java、Python、微信小程序、安卓、大数据、爬虫、Golang、大屏等实战项目。 ⛽⛽实战项目:有源码或者技术上的问题欢迎在评论区一起讨论交流! ⚡⚡ Java实战 |…

AF透明模式/虚拟网线模式组网部署

透明模式组网 实验拓扑 防火墙基本配置 接口配置 eth1 eth3 放通策略 1. 内网用户上班时间(9:00-17:00)不允许看视频、玩游戏及网上购物,其余时 间访问互联网不受限制;(20 分) 应用控制策略 2. 互联…

[论文笔记]RAFT: Adapting Language Model to Domain Specific RAG

引言 今天带来一篇结合RAG和微调的论文:RAFT: Adapting Language Model to Domain Specific RAG。 为了简单,下文中以翻译的口吻记录,比如替换"作者"为"我们"。 本文介绍了检索增强微调(Retrieval Augmented Fine Tunin…

【Impala SQL 造数(一)】

前言 SQL 造数即生成测试数据,一般是编码完成之后的测试阶段所需,测试数据可以用于多种目的,包括测试应用程序的功能、业务场景测试、性能测试、数据恢复测试等。在测试阶段特别是数据类需求,需要很多造数场景,像 Hiv…

尚品汇-支付宝支付同步异步回调实现(四十七)

目录: (1)订单支付码有效时间 (2)支付后回调—同步回调 (3)支付宝回调—异步回调 (1)订单支付码有效时间 (2)支付后回调—同步回调 static修饰…

【Jupyter Notebook】安装与使用

打开Anaconda Navigator点击"Install"(Launch安装前是Install)点击"Launch"点击"File"-"New"-"Notebook"​ 5.点击"Select"选择Python版本 6.输入测试代码并按"Enter+Shift"运行代码: 代码如下: …

考研系列-408真题数据结构篇(18-23)

写在前面 此文章是本人在备考过程中408真题数据结构部分(2018年-2023年)的易错题及相应的知识点整理,后期复习也尝尝用到,对于知识提炼归纳理解起到了很大的作用,分享出来希望帮助到大家~ # 2018年 1.堆的建立(从后往前进行调整) 2.算法(正整数和数组之间关系的建立)

k8s API资源对象ingress

有了Service之后,我们可以访问这个Service的IP(clusterIP)来请求对应的Pod,但是这只能是在集群内部访问。 要想让外部用户访问此资源,可以使用NodePort,即在node节点上暴漏一个端口出来,但是这…

pytorch+深度学习实现图像的神经风格迁移

本文的完整代码和部署教程已上传至本人的GitHub仓库,欢迎各位朋友批评指正! 1.各代码文件详解 1.1 train.py train.py 文件负责训练神经风格迁移模型。 加载内容和风格图片:使用 utils.load_image 函数加载并预处理内容和风格图片。初始化…

Banana Pi BPI-SM9 AI 计算模组采用算能科技BM1688芯片方案设计

产品概述 香蕉派 Banana Pi BPI-SM9 16-ENC-A3 深度学习计算模组搭载算能科技高集成度处理器 BM1688,功耗低、算力强、接口丰富、兼容性好。支持INT4/INT8/FP16/BF16/FP32混合精度计算,可支持 16 路高清视频实时分析,灵活应对图像、语音、自…

Java中等题-摆动序列(力扣)

如果连续数字之间的差严格地在正数和负数之间交替,则数字序列称为 摆动序列 。第一个差(如果存在的话)可能是正数或负数。仅有一个元素或者含两个不等元素的序列也视作摆动序列。 例如, [1, 7, 4, 9, 2, 5] 是一个 摆动序列 &…

数据库锁之行级锁、记录锁、间隙锁和临键锁

1. 行级锁 InnoDB 引擎支持行级锁,而MyISAM 引擎不支持行级锁,只支持表级锁。行级锁是基于索引实现的。 对于普通的select语句,是不会加记录锁的,因为它属于快照读,通过在MVCC中的undo log版本链实现。如果要在查询时对…

Python 安装selenium的办法

之前一直安装python以为要进入python的菜单进行输入 如下 老是提示错误,原来是我搞错了,安装这个直接进入cmd即可 如下 pip install selenium 再用pip list查看一下是否安装成功

git 提交代码由原先账号密码调整为ssh

如果你希望将 Git 提交代码的身份验证方式从用户名和密码切换到 SSH,你需要进行以下几个步骤: 1. 生成 SSH 密钥对 如果你还没有 SSH 密钥对,可以使用以下命令生成一个新的密钥对: ssh-keygen -t rsa -b 4096 -C "your_em…

基于SpringBoot校园快递代取系统

基于springbootvue实现的校园快递代取系统(源码L文ppt)4-049 3系统设计 3.1.1系统结构图 系统结构图可以把杂乱无章的模块按照设计者的思维方式进行调整排序,可以让设计者在之后的添加,修改程序内容…