nn.Embedding 根据索引生成的向量有权重吗

server/2024/10/10 20:38:32/

import torch
import torch.nn as nn

假设有一个大小为 10x3 的 Embedding 层,其中有 10 个单词,每个单词用一个长度为 3 的向量表示

num_words = 10
embedding_dim = 3

创建 Embedding 层

embedding_layer = nn.Embedding(num_words, embedding_dim)
print(embedding_layer.weight)
embedded_vectors = embedding_layer(torch.LongTensor([4]))
print(embedded_vectors)
embedded_vectors = embedding_layer(torch.LongTensor([5]))
print(embedded_vectors)
embedded_vectors = embedding_layer(torch.LongTensor([6]))
print(embedded_vectors)

在这里插入图片描述
nn.Embedding层单词转向量实测
1.nn.Embedding创建对象embedding_layer
2.可以看到embedding_layer创建完成,其属性weight已经有值了
3.embedding_layer方法传入分别torch.LongTensor([4]),torch.LongTensor([5]),(torch.LongTensor([6])生成的结果就是根据索引值去weight里取值。

打破猜测:
1.原以为embedding_layer里进行的一个乘法,传参*随机权重,如embedded_vectors =torch.LongTensor([4])*W,实际不是,没有乘法
2.实际是nn.Embedding(num_words, embedding_dim)根据参数已经随机生成了所有的向量,之后仅需根据索引取值

原始猜测:
1.由于程序每次重启embedding_layer.weight生成的参数随机,为供断点续训和预测用,这些参数不能每次都随机生成,所以这些应该是要保存在模型中。即断点续训或预测时,embedding层的向量不应是随机生成了,而是读取模型文件中存储的模型参数。
2.embedding_layer.weight参与梯度更新,一开始以为此处没有
3.一开始根据各种信息判断nn.Embedding的内部机制是,或许有一个随机参数,乘以输入单词索引,得到嵌入向量,并且这个参数不参加更新,潜意识是参数不保存。

1.为什么有或许有一个随机参数,乘以输入单词索引,得到嵌入向量这样的理解?因为看着是传入了索引,得到了一个随机向量,合理猜测应该是有个随即参数与传参相乘。所以这里第一步的猜测就错了。首先这个参数的确是有的,机制的实际是随机生成所有的可能的索引的向量,供直接取用。这里参数即嵌入向量2.这个参数不参加更新?这种不参与更新的参数,模型会保存吗?猜测应该不会,如果不会,那每次重启的得到的嵌入向量都变了,怎么供续训和预测用?

对于transformer/bert,网络上的确对nn.Embedding这一步骤的机制讲解不够清晰,不知道嵌入向量是怎么得出的,不知道其中是否有需要训练的参数。
嵌入参数不参加更新这说法主要是来自李宏毅讲的注意力机制那块的误解,说是除了wq,wk,wv参数参与训练,没有别的参数了。这就和续训预测产生了极强的的矛盾,难以判断。

当你创建 nn.Embedding 层时,PyTorch 会随机初始化权重。这些权重在训练过程中会通过反向传播进行更新,以拟合模型的输入和输出数据,确保模型能够更好地进行预测或分类任务。


http://www.ppmy.cn/server/51384.html

相关文章

IPython的使用技巧整理

IPython(Interactive Python)是一个增强的Python交互式解释器,提供了强大的功能和工具,帮助开发者更高效地编写和调试代码。以下是一些IPython的使用技巧,帮助你更好地利用这个工具: 1. 安装IPython IPyth…

Python网络安全项目开发实战,怎么扫描漏洞

注意:本文的下载教程,与以下文章的思路有相同点,也有不同点,最终目标只是让读者从多维度去熟练掌握本知识点。 下载教程: Python网络安全项目开发实战_扫描漏洞_编程案例解析实例详解课程教程.pdf Python在网络安全项目开发实战中扮演着至关重要的角色,尤其在漏洞扫描方面…

UFS Power Mode Change 介绍

一. UFS Power Mode Change简介 1.UFS Power Mode指的是Unipro层的Power State, 也可以称为链路(Link)上的Power Mode, 可以通过配置Unipro Attribute, 然后控制切换Unipro Power State, 当前Power Mode Change有两种触发方式: (1) 通过DME Power Mode Change触发…

从零实现GPT【1】——BPE

文章目录 Embedding 的原理训练特殊 token 处理和保存编码解码完整代码 BPE,字节对编码 Embedding 的原理 简单来说就是查表 # 解释embedding from torch.nn import Embedding import torch# 标准的正态分布初始化 也可以用均匀分布初始化 emb Embedding(10, 32) …

C++初学者指南第一步---13.聚合类型

C初学者指南第一步—13.聚合类型 文章目录 C初学者指南第一步---13.聚合类型1. 类型分类(简化)2. 如何定义和使用3. 为什么选择自定义类型/数据聚合?4. 聚合类型初始化5.混合6. 复制7. 值和引用的语义8.聚合的向量(std::vector)9.最令人烦恼的…

XSS+CSRF组合拳

目录 简介 如何进行实战 进入后台创建一个新用户进行接口分析 构造注入代码 寻找XSS漏洞并注入 小结 简介 (案例中将使用cms靶场来进行演示) 在实战中CSRF利用条件十分苛刻,因为我们需要让受害者点击我们的恶意请求不是一件容易的事情…

【MySQL数据库】:MySQL视图特性

目录 视图的概念 基本使用 准备测试表 创建视图 修改视图影响基表 修改基表影响视图 删除视图 视图规则和限制 视图的概念 视图是一个虚拟表,其内容由查询定义,同真实的表一样,视图包含一系列带有名称的列和行数据。视图中的数据…

vxe-table 列表过滤踩坑_vxe-table筛选

但是这个过滤输入值必须是跟列表的值必须一致才能查到,没做到模糊查询的功能,根据关键字来过滤并没有实现。 下面提供一下具体实现方法:(关键字来过滤) filterNameMethod({ option, row }) {if (row.name.indexOf(op…