pytorch小记(十四):pytorch中 nn.Embedding 详解

embedded/2025/3/26 18:29:43/

pytorch小记(十四):pytorch中 nn.Embedding 详解

  • PyTorch 中的 nn.Embedding 详解
    • 1. 什么是 nn.Embedding?
    • 2. nn.Embedding 的基本使用
      • 示例 1:基础用法
      • 示例 2:处理批次输入
    • 3. nn.Embedding 与 nn.Linear 的区别
      • 3.1 nn.Embedding
      • 3.2 nn.Linear
    • 4. nn.Embedding 与 nn.Sequential 的区别
    • 5. 应用场景
    • 6. 总结


PyTorch 中的 nn.Embedding 详解

在自然语言处理、推荐系统以及其他处理离散输入的任务中,我们常常需要将离散的标识符(例如单词、字符、用户 ID 等)转换为连续的、低维的向量表示。PyTorch 提供了专门的模块——nn.Embedding,用来实现这种“嵌入”操作。本文将详细解释 nn.Embedding 的工作原理、使用方法以及与普通线性层(nn.Linear)和顺序模块(nn.Sequential)的区别,并给出清晰的代码示例。


1. 什么是 nn.Embedding?

nn.Embedding 实际上是一个查找表(lookup table),它内部维护一个矩阵,每一行对应一个离散标识符的向量表示。

  • 假设你有一个词汇表,大小为 num_embeddings,每个词将映射到一个 embedding_dim 维的向量上。
  • nn.Embedding 会创建一个形状为 [num_embeddings, embedding_dim] 的矩阵。
  • 当你输入一个包含单词索引的张量时,模块会直接从这个矩阵中查找出相应行的向量,作为单词的嵌入表示。

这种方式的好处是直接“查找”而非进行繁琐的矩阵乘法计算,既高效又直观。


2. nn.Embedding 的基本使用

示例 1:基础用法

下面的例子展示了如何使用 nn.Embedding 将一组单词索引转换为对应的嵌入向量。

import torch
import torch.nn as nn# 定义一个嵌入层
# 假设词汇表大小为 10,每个单词用 5 维向量表示
embedding = nn.Embedding(num_embeddings=10, embedding_dim=5)# 打印嵌入矩阵的形状
print("嵌入矩阵形状:", embedding.weight.shape)
# 输出: torch.Size([10, 5])# 定义一个包含单词索引的张量,例如 [3, 7, 1]。索引 embedding 表中[3],[7],[1]行
indices = torch.tensor([3, 7, 1])# 使用嵌入层查找对应的嵌入向量
embedded_vectors = embedding(indices)
print("查找到的嵌入向量:")
print(embedded_vectors)

说明:

  • 输入是一个包含索引 [3, 7, 1] 的 1D 张量,输出是一个形状为 [3, 5] 的张量。
  • 每一行就是词汇表中对应索引的嵌入向量。

示例 2:处理批次输入

在实际任务中,我们通常一次处理多个样本。例如,一个批次中包含多个句子,每个句子由若干单词索引组成。下面的例子展示了如何处理批次数据。

# 假设有一个批次,包含 2 个样本,每个样本包含 4 个单词索引
'''
对应原数据的
[[[1],[2],[3],[4]],[[5],[6],[7],[8]]]
'''
batch_indices = torch.tensor([[1, 2, 3, 4],[5, 6, 7, 8]
])# 使用嵌入层查找嵌入向量
batch_embeddings = embedding(batch_indices)print("批次嵌入向量形状:", batch_embeddings.shape)
# 输出形状: torch.Size([2, 4, 5])

说明:

  • 输入的 batch_indices 形状为 [2, 4],表示 2 个样本,每个样本 4 个单词索引。
  • 输出为 [2, 4, 5],每个单词索引转换成 5 维嵌入向量。

3. nn.Embedding 与 nn.Linear 的区别

虽然 nn.Embedding 和 nn.Linear 都涉及到矩阵的操作,但二者解决的问题大不相同。

3.1 nn.Embedding

  • 用途:专门用于将离散的索引(如单词 ID)转换为连续的向量表示,是一种查找表操作。
  • 输入:通常为整数索引。
  • 输出:直接返回查找表中对应的向量,效率高,不进行额外的计算。

3.2 nn.Linear

  • 用途:用于实现线性变换,即对输入做矩阵乘法加上偏置,计算公式为 y = x W T + b y = xW^T + b y=xWT+b
  • 输入:需要连续数值的张量。
  • 应用:若要模拟嵌入操作,需要先将整数索引转换成 one-hot 编码,再通过 nn.Linear 进行计算,这样既低效又不直观。

总结

  • 使用 nn.Embedding 更直接、更高效,因为它只进行查找操作;
  • nn.Linear 则用于对连续特征进行线性变换。

4. nn.Embedding 与 nn.Sequential 的区别

  • nn.Sequential 是一个模块容器,用于按顺序组合多个层,适用于前向传播流程固定的情况。
  • nn.Embedding 则是一个具体的层,用于实现查找表功能。
  • 在模型中,我们通常将 nn.Embedding 放在最前面,将离散输入转换为连续向量,再结合 nn.Sequential 里的其他层进行进一步处理。

例如,在 NLP 模型中常常这样使用:

class TextModel(nn.Module):def __init__(self, vocab_size, embed_dim):super(TextModel, self).__init__()# 使用 nn.Embedding 将单词索引映射为嵌入向量self.embedding = nn.Embedding(vocab_size, embed_dim)# 使用 nn.Sequential 组合后续的线性层self.fc = nn.Sequential(nn.Linear(embed_dim, 10),nn.ReLU(),nn.Linear(10, 2))def forward(self, x):# x 的形状可能为 (batch_size, sequence_length)x_embed = self.embedding(x)  # 变为 (batch_size, sequence_length, embed_dim)# 对嵌入向量进行池化,变为 (batch_size, embed_dim)x_pool = x_embed.mean(dim=1)out = self.fc(x_pool)return out# 假设词汇表大小 100,嵌入维度 8
model = TextModel(vocab_size=100, embed_dim=8)

在这个例子中,nn.Embedding 将离散单词转换为连续向量,而 nn.Sequential 则定义了后续的前向传播步骤。


5. 应用场景

nn.Embedding 常用于:

  • 自然语言处理(NLP):将单词、子词、字符等离散输入转换为低维向量表示,为后续的 RNN、Transformer 模型提供输入。
  • 推荐系统:将用户 ID、商品 ID 映射为嵌入向量,用于捕捉用户和物品之间的相似性。
  • 图神经网络:将节点或边的离散标签转换为连续向量表示。

6. 总结

  • nn.Embedding 是一个查找表,用于将离散索引映射为连续向量。
  • 它的输入通常是整数张量,输出是对应的嵌入向量。
  • 与 nn.Linear 相比,nn.Embedding 不需要进行大量的计算,只是直接查找,所以更高效。
  • nn.Embedding 经常与 nn.Sequential 结合使用:先将离散数据转换为嵌入向量,再通过连续层进行处理。

通过以上详细解释和分步代码示例,希望大家能对 nn.Embedding 有一个全面的理解,并能在实际项目中正确使用它来提升模型的表现。

🚀 写在最后
利用 nn.Embedding,你可以轻松将离散数据转换为高质量的连续表示,这在各种深度学习任务中都是至关重要的!


http://www.ppmy.cn/embedded/174691.html

相关文章

针对 pdf.mjs 文件因 MIME 类型错误导致的 Failed to load module script 问题解决方案

Failed to load module script: Expected a JavaScript module script but the server responded with a MIME type of “application/octet-stream”. Strict MIME type checking is enforced for module scripts per HTML spec. pdf.mjs 这种问题该如何处理 nginx 针对 pdf.…

stm32h743一个定时器实现秒和微妙延时

一、定时器溢出公式 Tout(溢出时间)(ARR1)(PSC1)/Tclk 二、确定定时器类型 三种定时器:基础定时器,通用定时器,高级定时器 STM32H743有众多的定时器,其中包括 2 个基本定时器( TIM6 和 TIM7…

INT202 Complexity of Algroithms 算法的复杂度 Pt.2 Search Algorithm 搜索算法

文章目录 1.树的数据结构1.1 有序数据(Ordered Data)1.1.1 有序字典(Ordered Dictonary)1.1.1.1 排序表(Sorted Tables) 1.2 二分查找(Binary Search)1.2.1 二分查找的时间复杂度 1.3 二叉搜索树&#xff0…

正则表达式:文本处理的瑞士军刀

正则表达式:文本处理的瑞士军刀 正则表达式(Regular Expression,简称 Regex)是一种用于匹配、查找和操作文本的强大工具。它通过定义一种特殊的字符串模式,可以快速地在文本中搜索、替换或提取符合特定规则的内容。正…

mysql中的游标是什么?作用是什么?

MySQL中的游标(Cursor)是一种用于在存储过程或函数中逐行处理查询结果集的数据库对象。它的作用类似于编程语言中的迭代器,允许遍历查询返回的多行数据,并对每一行执行特定操作。 游标的作用 逐行处理数据 当查询返回多行结果时&a…

使用Python在Word中创建、读取和删除列表 - 详解

目录 工具与设置 Python在Word中创建列表 使用默认样式创建有序(编号)列表 使用默认样式创建无序(项目符号)列表 创建多级列表 使用自定义样式创建列表 Python读取Word中的列表 Python从Word中删除列表 在Word中&#xff…

洛谷题目:P1018 [NOIP 2000 提高组] 乘积最大 题解

题目传送门: P1018 [NOIP 2000 提高组] 乘积最大 - 洛谷 (luogu.com.cn) 前言: 本题可以使用DP来解决。动态规划的核心思想是将一个复杂的问题分解为多个简单的子问题,并通过求解子问题的最优解来得到原问题的最优解。在本题中&#xff0c…

高频SQL50题 第一天 | 1757. 可回收且低脂的产品、584. 寻找用户推荐人、595. 大的国家、1683. 无效的推文、1148. 文章浏览 I

1757. 可回收且低脂的产品 题目链接:https://leetcode.cn/problems/recyclable-and-low-fat-products/description/?envTypestudy-plan-v2&envIdsql-free-50 状态:已完成 考点:无 select product_id from Products where low_fats Y a…