Pytorch中的ebmedding到底怎么理解?

embedded/2025/3/3 15:57:06/

在 PyTorch 中,nn.Embedding 是一个用于处理离散符号映射到连续向量空间的模块。它通常用于自然语言处理(NLP)任务(如词嵌入)、处理分类特征,或任何需要将离散索引转换为密集向量的场景。


核心理解

  1. 功能

    • 将离散的整数索引(例如单词的索引、类别ID)映射为固定维度的连续向量。
    • 这些向量是可学习的参数,在训练过程中通过反向传播优化。
  2. 参数

    • num_embeddings:词汇表的大小(有多少个唯一的符号/类别)。
    • embedding_dim:每个符号对应的向量维度。
    • 例如:nn.Embedding(1000, 128) 表示将 1000 个符号映射到 128 维的向量空间。
  3. 输入与输出

    • 输入:一个整数张量,形状为 (*)(可以是任意维度,通常是 [batch_size, sequence_length])。
    • 输出:形状为 (*, embedding_dim) 的张量。例如,输入形状为 [2, 3],输出为 [2, 3, 128]

工作原理

  1. 内部权重矩阵

    • nn.Embedding 内部维护一个形状为 (num_embeddings, embedding_dim) 的权重矩阵。
    • 当输入索引 i 时,输出是该矩阵的第 i 行(即 weight[i])。
  2. 类比 One-Hot + 全连接层

    • 可以理解为对输入进行 One-Hot 编码,然后通过一个 无偏置的全连接层
    • 例如,输入 3 会转换为一个 One-Hot 向量 [0,0,0,1,0,...],再与权重矩阵相乘,直接取出第 3 行的向量。
    • 但实际实现是高效的直接索引查找,避免了显式的 One-Hot 计算。

使用示例

import torch
import torch.nn as nn# 定义 Embedding 层:10 个符号,每个符号映射到 3 维向量
embedding = nn.Embedding(num_embeddings=10, embedding_dim=3)# 输入:形状为 [2, 4] 的整数张量(例如,两个样本,每个样本长度为4)
input_indices = torch.LongTensor([[1,2,4,5], [4,3,2,9]])# 输出:形状为 [2, 4, 3]
output = embedding(input_indices)
print(output)

关键特性

  1. 可学习的参数

    • 通过 embedding.weight 可以访问或修改权重矩阵(例如加载预训练词向量)。
    • 默认初始化:权重矩阵的值从正态分布 N(0,1) 中随机采样。
  2. 填充索引(Padding)

    • 通过 padding_idx 参数指定填充位置的索引(例如 padding_idx=0),使该位置的向量在训练中不更新。
  3. 冻结权重

    • 通过 embedding.weight.requires_grad_(False) 可以冻结参数,使其不参与训练。

应用场景

  1. 词嵌入(Word Embedding)

    vocab_size = 5000  # 词汇表大小
    embedding_dim = 300
    embedding_layer = nn.Embedding(vocab_size, embedding_dim)
    
  2. 类别特征嵌入

    • 处理分类特征时,将类别ID转换为向量(类似One-Hot的密集版本)。
  3. 推荐系统

    • 用户ID、物品ID的嵌入表示。

注意事项

  1. 输入范围

    • 输入的索引必须在 [0, num_embeddings-1] 范围内,否则会报错。
  2. 梯度传播

    • 只有实际被用到的索引对应的向量会更新梯度(未被使用的索引不影响模型参数)。
  3. 预训练初始化

    • 可以加载预训练的权重(如 Word2Vec、GloVe):
      embedding_layer.weight.data.copy_(torch.from_numpy(pretrained_matrix))
      

总结

nn.Embedding 是 PyTorch 中实现嵌入操作的核心模块,它将离散符号映射到连续的语义空间,是处理符号数据的基础工具。通过训练,模型可以自动学习符号之间的语义关系(例如相似性)。


http://www.ppmy.cn/embedded/169650.html

相关文章

计算机毕业设计Python+DeepSeek-R1大模型期货价格预测分析 期货价格数据分析可视化预测系 统 量化交易大数据 机器学习 深度学习

温馨提示:文末有 CSDN 平台官方提供的学长联系方式的名片! 温馨提示:文末有 CSDN 平台官方提供的学长联系方式的名片! 温馨提示:文末有 CSDN 平台官方提供的学长联系方式的名片! 作者简介:Java领…

认知动力学视角下的生命优化系统:多模态机器学习框架的哲学重构

认知动力学视角下的生命优化系统:多模态机器学习框架的哲学重构 一、信息熵与生命系统的耗散结构 在热力学第二定律框架下,生命系统可视为负熵流的耗散结构: d S d i S d e S dS d_iS d_eS dSdi​Sde​S 其中 d i S d_iS di​S为内部熵…

next实现原理

Next.js 是一个基于 React 的 服务器端渲染(SSR) 和 静态生成(SSG) 框架,它的实现原理涉及多个关键技术点,包括 服务端渲染(SSR)、静态生成(SSG)、客户端渲染…

《Effective Objective-C》阅读笔记(下)

目录 内存管理 理解引用计数 引用计数工作原理 自动释放池 保留环 以ARC简化引用计数 使用ARC时必须遵循的方法命名规则 变量的内存管理语义 ARC如何清理实例变量 在dealloc方法中只释放引用并解除监听 编写“异常安全代码”时留意内存管理问题 以弱引用避免保留环 …

基于ArcGIS Pro、Python、USLE、INVEST模型等多技术融合的生态系统服务构建生态安全格局高阶应用

文字目录 前言第一章、生态安全评价理论及方法介绍一、生态安全评价简介二、生态服务能力简介三、生态安全格局构建研究方法简介 第二章、平台基础一、ArcGIS Pro介绍二、Python环境配置 第三章、数据获取与清洗一、数据获取:二、数据预处理(ArcGIS Pro及…

Chapter 4 Noise performance of elementary transistor stages

Chapter 4 Noise performance of elementary transistor stages 在介绍运放之前, 这一章介绍噪声 噪声是看RMS电压, 即Root-Mean-Square Voltage V R M S 1 T ∫ 0 T V ( t ) 2 d t V P 2 V_{RMS}\sqrt{\frac{1}{T}\int_{0}^{T}V(t)^2dt}\frac{V_P}{\sqrt{2}} VRMS​T1​∫…

如何长期保存数据(不包括云存储)最安全有效?

互联网各领域资料分享专区(不定期更新): Sheet 前言 这个问题需要考虑多个方面,比如存储介质的寿命、数据完整性、访问的便捷性,还有成本等因素。长期保存的话,存储介质的耐久性很重要。比如常见的硬盘、SSD、光盘、磁带等,各有优缺点。机械硬盘(HDD)的寿命一般在3-5年,…

IT安全运维指南:手册、工具与资源速览

1. IT安全运维手册 飞塔防火墙手册:https://handbook.fortinet.com.cn/ 亿邮邮箱系统手册:https://mail.eyou.net/?qhelp 深信服上网行为管理手册:https://support.sangfor.com.cn/productDocument/read?product_id22&version_id943 …