pytorch小记(十四):pytorch中 nn.Embedding 详解

devtools/2025/3/19 12:29:53/

pytorch小记(十四):pytorch中 nn.Embedding 详解

  • PyTorch 中的 nn.Embedding 详解
    • 1. 什么是 nn.Embedding?
    • 2. nn.Embedding 的基本使用
      • 示例 1:基础用法
      • 示例 2:处理批次输入
    • 3. nn.Embedding 与 nn.Linear 的区别
      • 3.1 nn.Embedding
      • 3.2 nn.Linear
    • 4. nn.Embedding 与 nn.Sequential 的区别
    • 5. 应用场景
    • 6. 总结


PyTorch 中的 nn.Embedding 详解

在自然语言处理、推荐系统以及其他处理离散输入的任务中,我们常常需要将离散的标识符(例如单词、字符、用户 ID 等)转换为连续的、低维的向量表示。PyTorch 提供了专门的模块——nn.Embedding,用来实现这种“嵌入”操作。本文将详细解释 nn.Embedding 的工作原理、使用方法以及与普通线性层(nn.Linear)和顺序模块(nn.Sequential)的区别,并给出清晰的代码示例。


1. 什么是 nn.Embedding?

nn.Embedding 实际上是一个查找表(lookup table),它内部维护一个矩阵,每一行对应一个离散标识符的向量表示。

  • 假设你有一个词汇表,大小为 num_embeddings,每个词将映射到一个 embedding_dim 维的向量上。
  • nn.Embedding 会创建一个形状为 [num_embeddings, embedding_dim] 的矩阵。
  • 当你输入一个包含单词索引的张量时,模块会直接从这个矩阵中查找出相应行的向量,作为单词的嵌入表示。

这种方式的好处是直接“查找”而非进行繁琐的矩阵乘法计算,既高效又直观。


2. nn.Embedding 的基本使用

示例 1:基础用法

下面的例子展示了如何使用 nn.Embedding 将一组单词索引转换为对应的嵌入向量。

import torch
import torch.nn as nn# 定义一个嵌入层
# 假设词汇表大小为 10,每个单词用 5 维向量表示
embedding = nn.Embedding(num_embeddings=10, embedding_dim=5)# 打印嵌入矩阵的形状
print("嵌入矩阵形状:", embedding.weight.shape)
# 输出: torch.Size([10, 5])# 定义一个包含单词索引的张量,例如 [3, 7, 1]。索引 embedding 表中[3],[7],[1]行
indices = torch.tensor([3, 7, 1])# 使用嵌入层查找对应的嵌入向量
embedded_vectors = embedding(indices)
print("查找到的嵌入向量:")
print(embedded_vectors)

说明:

  • 输入是一个包含索引 [3, 7, 1] 的 1D 张量,输出是一个形状为 [3, 5] 的张量。
  • 每一行就是词汇表中对应索引的嵌入向量。

示例 2:处理批次输入

在实际任务中,我们通常一次处理多个样本。例如,一个批次中包含多个句子,每个句子由若干单词索引组成。下面的例子展示了如何处理批次数据。

# 假设有一个批次,包含 2 个样本,每个样本包含 4 个单词索引
'''
对应原数据的
[[[1],[2],[3],[4]],[[5],[6],[7],[8]]]
'''
batch_indices = torch.tensor([[1, 2, 3, 4],[5, 6, 7, 8]
])# 使用嵌入层查找嵌入向量
batch_embeddings = embedding(batch_indices)print("批次嵌入向量形状:", batch_embeddings.shape)
# 输出形状: torch.Size([2, 4, 5])

说明:

  • 输入的 batch_indices 形状为 [2, 4],表示 2 个样本,每个样本 4 个单词索引。
  • 输出为 [2, 4, 5],每个单词索引转换成 5 维嵌入向量。

3. nn.Embedding 与 nn.Linear 的区别

虽然 nn.Embedding 和 nn.Linear 都涉及到矩阵的操作,但二者解决的问题大不相同。

3.1 nn.Embedding

  • 用途:专门用于将离散的索引(如单词 ID)转换为连续的向量表示,是一种查找表操作。
  • 输入:通常为整数索引。
  • 输出:直接返回查找表中对应的向量,效率高,不进行额外的计算。

3.2 nn.Linear

  • 用途:用于实现线性变换,即对输入做矩阵乘法加上偏置,计算公式为 y = x W T + b y = xW^T + b y=xWT+b
  • 输入:需要连续数值的张量。
  • 应用:若要模拟嵌入操作,需要先将整数索引转换成 one-hot 编码,再通过 nn.Linear 进行计算,这样既低效又不直观。

总结

  • 使用 nn.Embedding 更直接、更高效,因为它只进行查找操作;
  • nn.Linear 则用于对连续特征进行线性变换。

4. nn.Embedding 与 nn.Sequential 的区别

  • nn.Sequential 是一个模块容器,用于按顺序组合多个层,适用于前向传播流程固定的情况。
  • nn.Embedding 则是一个具体的层,用于实现查找表功能。
  • 在模型中,我们通常将 nn.Embedding 放在最前面,将离散输入转换为连续向量,再结合 nn.Sequential 里的其他层进行进一步处理。

例如,在 NLP 模型中常常这样使用:

class TextModel(nn.Module):def __init__(self, vocab_size, embed_dim):super(TextModel, self).__init__()# 使用 nn.Embedding 将单词索引映射为嵌入向量self.embedding = nn.Embedding(vocab_size, embed_dim)# 使用 nn.Sequential 组合后续的线性层self.fc = nn.Sequential(nn.Linear(embed_dim, 10),nn.ReLU(),nn.Linear(10, 2))def forward(self, x):# x 的形状可能为 (batch_size, sequence_length)x_embed = self.embedding(x)  # 变为 (batch_size, sequence_length, embed_dim)# 对嵌入向量进行池化,变为 (batch_size, embed_dim)x_pool = x_embed.mean(dim=1)out = self.fc(x_pool)return out# 假设词汇表大小 100,嵌入维度 8
model = TextModel(vocab_size=100, embed_dim=8)

在这个例子中,nn.Embedding 将离散单词转换为连续向量,而 nn.Sequential 则定义了后续的前向传播步骤。


5. 应用场景

nn.Embedding 常用于:

  • 自然语言处理(NLP):将单词、子词、字符等离散输入转换为低维向量表示,为后续的 RNN、Transformer 模型提供输入。
  • 推荐系统:将用户 ID、商品 ID 映射为嵌入向量,用于捕捉用户和物品之间的相似性。
  • 图神经网络:将节点或边的离散标签转换为连续向量表示。

6. 总结

  • nn.Embedding 是一个查找表,用于将离散索引映射为连续向量。
  • 它的输入通常是整数张量,输出是对应的嵌入向量。
  • 与 nn.Linear 相比,nn.Embedding 不需要进行大量的计算,只是直接查找,所以更高效。
  • nn.Embedding 经常与 nn.Sequential 结合使用:先将离散数据转换为嵌入向量,再通过连续层进行处理。

通过以上详细解释和分步代码示例,希望大家能对 nn.Embedding 有一个全面的理解,并能在实际项目中正确使用它来提升模型的表现。

🚀 写在最后
利用 nn.Embedding,你可以轻松将离散数据转换为高质量的连续表示,这在各种深度学习任务中都是至关重要的!


http://www.ppmy.cn/devtools/168335.html

相关文章

4.从GitHub拉取远程分支到本地

要从GitHub拉取远程分支到本地,可以按以下步骤操作: 1. 方法一:直接拉取并切换到分支 适用场景 远程分支已存在(例如 feature/new-ui),需拉取到本地并自动跟踪。 拉取所有远程分支信息(确保本…

AI+遥感:农作物病虫害实时预警新突破!

引言:一场静悄悄的田间革命 在河南省周口市的麦田里,一架搭载多光谱相机的无人机正在低空盘旋。远在2000公里外的北京国家农业大数据中心,AI系统突然发出警报:北纬33.76、东经114.63区域出现条锈病早期感染迹象,感染面…

【Redis】Redis的数据删除(过期)策略,数据淘汰策略。

如果问到:假如Redis的key过期之后,会立即删除吗? 其实就是想问数据删除(过期)策略。 如果面试官问到:如果缓存过多,内存是有限的,内存被占满了怎么办? 其实就是问:数据的淘汰策略。…

中电金信25/3/18面前笔试(需求分析岗+数据开发岗)

部分相同题目在第二次数据开发岗中不做解析,本次解析来源于豆包AI,正确与否有待商榷,本文只提供一个速查与知识点的补充。 一、需求分析 第1题,单选题,Hadoop的核心组件包括HDFS和以下哪个? MapReduce Spark Storm…

android开发:组件事件汇总

在 Android 开发中,Java 文件中有许多组件事件可以处理用户交互。以下是一些常见的组件事件及其用途和示例: 1. 点击事件 (Click) 用于处理用户点击控件的操作。 示例代码: Button button findViewById(R.id.button); button.setOnClickL…

【GNN】GAT

消息传递 层数越多,聚合更多的消息

Mysql表的简单操作

🏝️专栏:Mysql_猫咪-9527的博客-CSDN博客 🌅主页:猫咪-9527-CSDN博客 “欲穷千里目,更上一层楼。会当凌绝顶,一览众山小。” 目录 3.1 创建表 3.2 查看表结构 3.3 修改表 1. 添加字段 2. 修改字段 3…

202年充电计划——自学手册 网络安全(黑客技术)

🤟 基于入门网络安全/黑客打造的:👉黑客&网络安全入门&进阶学习资源包 前言 什么是网络安全 网络安全可以基于攻击和防御视角来分类,我们经常听到的 “红队”、“渗透测试” 等就是研究攻击技术,而“蓝队”、“…