pytorch小记(十四):pytorch中 nn.Embedding 详解
- PyTorch 中的 nn.Embedding 详解
- 1. 什么是 nn.Embedding?
- 2. nn.Embedding 的基本使用
- 示例 1:基础用法
- 示例 2:处理批次输入
- 3. nn.Embedding 与 nn.Linear 的区别
- 3.1 nn.Embedding
- 3.2 nn.Linear
- 4. nn.Embedding 与 nn.Sequential 的区别
- 5. 应用场景
- 6. 总结
PyTorch 中的 nn.Embedding 详解
在自然语言处理、推荐系统以及其他处理离散输入的任务中,我们常常需要将离散的标识符(例如单词、字符、用户 ID 等)转换为连续的、低维的向量表示。PyTorch 提供了专门的模块——nn.Embedding,用来实现这种“嵌入”操作。本文将详细解释 nn.Embedding 的工作原理、使用方法以及与普通线性层(nn.Linear)和顺序模块(nn.Sequential)的区别,并给出清晰的代码示例。
1. 什么是 nn.Embedding?
nn.Embedding 实际上是一个查找表(lookup table),它内部维护一个矩阵,每一行对应一个离散标识符的向量表示。
- 假设你有一个词汇表,大小为
num_embeddings
,每个词将映射到一个embedding_dim
维的向量上。 - nn.Embedding 会创建一个形状为
[num_embeddings, embedding_dim]
的矩阵。 - 当你输入一个包含单词索引的张量时,模块会直接从这个矩阵中查找出相应行的向量,作为单词的嵌入表示。
这种方式的好处是直接“查找”而非进行繁琐的矩阵乘法计算,既高效又直观。
2. nn.Embedding 的基本使用
示例 1:基础用法
下面的例子展示了如何使用 nn.Embedding 将一组单词索引转换为对应的嵌入向量。
import torch
import torch.nn as nn# 定义一个嵌入层
# 假设词汇表大小为 10,每个单词用 5 维向量表示
embedding = nn.Embedding(num_embeddings=10, embedding_dim=5)# 打印嵌入矩阵的形状
print("嵌入矩阵形状:", embedding.weight.shape)
# 输出: torch.Size([10, 5])# 定义一个包含单词索引的张量,例如 [3, 7, 1]。索引 embedding 表中[3],[7],[1]行
indices = torch.tensor([3, 7, 1])# 使用嵌入层查找对应的嵌入向量
embedded_vectors = embedding(indices)
print("查找到的嵌入向量:")
print(embedded_vectors)
说明:
- 输入是一个包含索引
[3, 7, 1]
的 1D 张量,输出是一个形状为[3, 5]
的张量。 - 每一行就是词汇表中对应索引的嵌入向量。
示例 2:处理批次输入
在实际任务中,我们通常一次处理多个样本。例如,一个批次中包含多个句子,每个句子由若干单词索引组成。下面的例子展示了如何处理批次数据。
# 假设有一个批次,包含 2 个样本,每个样本包含 4 个单词索引
'''
对应原数据的
[[[1],[2],[3],[4]],[[5],[6],[7],[8]]]
'''
batch_indices = torch.tensor([[1, 2, 3, 4],[5, 6, 7, 8]
])# 使用嵌入层查找嵌入向量
batch_embeddings = embedding(batch_indices)print("批次嵌入向量形状:", batch_embeddings.shape)
# 输出形状: torch.Size([2, 4, 5])
说明:
- 输入的
batch_indices
形状为[2, 4]
,表示 2 个样本,每个样本 4 个单词索引。 - 输出为
[2, 4, 5]
,每个单词索引转换成 5 维嵌入向量。
3. nn.Embedding 与 nn.Linear 的区别
虽然 nn.Embedding 和 nn.Linear 都涉及到矩阵的操作,但二者解决的问题大不相同。
3.1 nn.Embedding
- 用途:专门用于将离散的索引(如单词 ID)转换为连续的向量表示,是一种查找表操作。
- 输入:通常为整数索引。
- 输出:直接返回查找表中对应的向量,效率高,不进行额外的计算。
3.2 nn.Linear
- 用途:用于实现线性变换,即对输入做矩阵乘法加上偏置,计算公式为 y = x W T + b y = xW^T + b y=xWT+b。
- 输入:需要连续数值的张量。
- 应用:若要模拟嵌入操作,需要先将整数索引转换成 one-hot 编码,再通过 nn.Linear 进行计算,这样既低效又不直观。
总结:
- 使用 nn.Embedding 更直接、更高效,因为它只进行查找操作;
- nn.Linear 则用于对连续特征进行线性变换。
4. nn.Embedding 与 nn.Sequential 的区别
- nn.Sequential 是一个模块容器,用于按顺序组合多个层,适用于前向传播流程固定的情况。
- nn.Embedding 则是一个具体的层,用于实现查找表功能。
- 在模型中,我们通常将 nn.Embedding 放在最前面,将离散输入转换为连续向量,再结合 nn.Sequential 里的其他层进行进一步处理。
例如,在 NLP 模型中常常这样使用:
class TextModel(nn.Module):def __init__(self, vocab_size, embed_dim):super(TextModel, self).__init__()# 使用 nn.Embedding 将单词索引映射为嵌入向量self.embedding = nn.Embedding(vocab_size, embed_dim)# 使用 nn.Sequential 组合后续的线性层self.fc = nn.Sequential(nn.Linear(embed_dim, 10),nn.ReLU(),nn.Linear(10, 2))def forward(self, x):# x 的形状可能为 (batch_size, sequence_length)x_embed = self.embedding(x) # 变为 (batch_size, sequence_length, embed_dim)# 对嵌入向量进行池化,变为 (batch_size, embed_dim)x_pool = x_embed.mean(dim=1)out = self.fc(x_pool)return out# 假设词汇表大小 100,嵌入维度 8
model = TextModel(vocab_size=100, embed_dim=8)
在这个例子中,nn.Embedding 将离散单词转换为连续向量,而 nn.Sequential 则定义了后续的前向传播步骤。
5. 应用场景
nn.Embedding 常用于:
- 自然语言处理(NLP):将单词、子词、字符等离散输入转换为低维向量表示,为后续的 RNN、Transformer 模型提供输入。
- 推荐系统:将用户 ID、商品 ID 映射为嵌入向量,用于捕捉用户和物品之间的相似性。
- 图神经网络:将节点或边的离散标签转换为连续向量表示。
6. 总结
- nn.Embedding 是一个查找表,用于将离散索引映射为连续向量。
- 它的输入通常是整数张量,输出是对应的嵌入向量。
- 与 nn.Linear 相比,nn.Embedding 不需要进行大量的计算,只是直接查找,所以更高效。
- nn.Embedding 经常与 nn.Sequential 结合使用:先将离散数据转换为嵌入向量,再通过连续层进行处理。
通过以上详细解释和分步代码示例,希望大家能对 nn.Embedding 有一个全面的理解,并能在实际项目中正确使用它来提升模型的表现。
🚀 写在最后:
利用 nn.Embedding,你可以轻松将离散数据转换为高质量的连续表示,这在各种深度学习任务中都是至关重要的!