Pytorch的默认初始化分布 nn.Embedding.weight初始化分布

news/2025/3/15 1:11:34/

一、nn.Embedding.weight初始化分布

 

nn.Embedding.weight随机初始化方式是标准正态分布 [公式] ,即均值$\mu=0$,方差$\sigma=1$的正态分布。

 

论据1——查看源代码

 

## class Embedding具体实现(在此只展示部分代码)
import torch
from torch.nn.parameter import Parameterfrom .module import Module
from .. import functional as Fclass Embedding(Module):def __init__(self, num_embeddings, embedding_dim, padding_idx=None,max_norm=None, norm_type=2, scale_grad_by_freq=False,sparse=False, _weight=None):if _weight is None:self.weight = Parameter(torch.Tensor(num_embeddings, embedding_dim))self.reset_parameters()else:assert list(_weight.shape) == [num_embeddings, embedding_dim], \'Shape of weight does not match num_embeddings and embedding_dim'self.weight = Parameter(_weight)def reset_parameters(self):self.weight.data.normal_(0, 1)if self.padding_idx is not None:self.weight.data[self.padding_idx].fill_(0)

 

 

Embedding这个类有个属性weight,它是torch.nn.parameter.Parameter类型的,作用就是存储真正的word embeddings。如果不给weight赋值,Embedding类会自动给他初始化,看上述代码第6~8行,如果属性weight没有手动赋值,则会定义一个torch.nn.parameter.Parameter对象,然后对该对象进行reset_parameters(),看第21行,对self.weight先转为Tensor在对其进行normal_(0, 1)(调整为$N(0, 1)$正态分布)。所以nn.Embeddig.weight默认初始化方式就是N(0, 1)分布,即均值$\mu=0$,方差$\sigma=1$的标准正态分布。

 

论据2——简单验证nn.Embeddig.weight的分布

 

下面将做的是验证nn.Embeddig.weight某一行词向量的均值和方差,以便验证是否为标准正态分布。
注意:验证一行数字的均值为0,方差为1,显然不能说明该分布就是标准正态分布,只能是其必要条件,而不是充分条件,要想真正检测这行数字是不是正态分布,在概率论上有专门的较为复杂的方法,请查看概率论之假设检验。

 

import torch.nn as nn# dim越大,均值、方差越接近0和1
dim = 800000
# 定义了一个(5, dim)的二维embdding
# 对于NLP来说,相当于是5个词,每个词的词向量维数是dim
# 每个词向量初始化为正态分布 N(0,1)(待验证)
embd = nn.Embedding(5, dim)
# type(embd.weight) is Parameter
# type(embd.weight.data) is Tensor
# embd.weight.data[0]是指(5, dim)的word embeddings中取第1个词的词向量,是dim维行向量
weight = embd.weight.data[0].numpy()
print("weight: {}".format(weight))weight_sum = 0
for w in weight:weight_sum += w
mean = weight_sum / dim
print("均值: {}".format(mean))square_sum = 0
for w in weight:square_sum += (mean - w) ** 2
print("方差: {}".format(square_sum / dim))

 

 

代码输出:

 

weight: [-0.65507996  0.11627434 -1.6705967  ...  0.78397447  ...  -0.13477565]
均值: 0.0006973597864689242
方差: 1.0019535550544454

 

 

可见,均值接近0,方差接近1,从这里也可以反映出nn.Embeddig.weight是标准正态分布$N(0, 1)$。

 

二、torch.Tensortorch.tensortorch.randn初始化分布

 

1、torch.rand

 

返回$[0,1)$上的均匀分布(uniform distribution)。

 

2、torch.randn

 

返回$N(0, 1)$,即标准正态分布(standard normal distribution)。

 

3、torch.Tensor

 

torch.Tensor是Tensor class,torch.Tensor(2, 3)是调用Tensor的构造函数,构造了$2\times3$矩阵,但是没有分配空间,未初始化。
不推荐使用torch.Tensor创建Tensor,应使用torch.tenstortorch.onestorch.zerostorch.randtorch.randn等,原因:

 

t = torch.Tensor(2,3)
# 容易出现下述错误,因为t中的值取决当前内存中的随机值
# 如果当前内存中随机值特别大会溢出
RuntimeError: Overflow when unpacking long
 

http://www.ppmy.cn/news/606230.html

相关文章

小芯片与大芯片技术

小芯片与大芯片技术 芯片尺寸构装(Chip Scale Package, CSP)是一种半导体构装技术。。作为新一代的芯片封装技术,在TSOP、BGA的基础上,CSP的性能又有了革命性的提升。 CSP,全称为Chip Scale Package,即芯片…

LeetCode简单题之两个列表的最小索引总和

题目 假设 Andy 和 Doris 想在晚餐时选择一家餐厅,并且他们都有一个表示最喜爱餐厅的列表,每个餐厅的名字用字符串表示。 你需要帮助他们用最少的索引和找出他们共同喜爱的餐厅。 如果答案不止一个,则输出所有答案并且不考虑顺序。 你可以假…

apt命令概述,apt命令在Ubuntu16.04安装openjdk-7-jdk

apt是一条linux命令,适用于deb包管理式操作系统,主要用于自动从互联网的软件仓库中搜索、安装、升级、卸载软件或操作系统。deb包是Debian 软件包格式的文件扩展名。 翻译过来就是: apt是一个命令行包管理器,为 搜索和管理以及查询…

LeetCode简单题之数组形式的整数加法

题目 对于非负整数 X 而言,X 的数组形式是每位数字按从左到右的顺序形成的数组。例如,如果 X 1231,那么其数组形式为 [1,2,3,1]。 给定非负整数 X 的数组形式 A,返回整数 XK 的数组形式。 示例 1: 输入:A…

Linux学习(3)——安装vmtools

安装vmtools vmtools安装后,我们可以在Windows下更好的管理vm虚拟机可以设置Windows和centos的共享文件夹 安装步骤进入centos点击vm菜单的–>install vmware toolscentos会 出现一个vm的安装包,xx.tar.gz拷贝到/opt使用解压命令tar,得到一个安装文件 cd /opt/[…

从DPU开始到RDMA到CUDA

从DPU开始到RDMA到CUDA DPU是Data Processing Unit的简称,它是最新发展起来的专用处理器的一个大类,是继CPU、GPU之后,数据中心场景中的第三颗重要的算力芯片,为高带宽、低延迟、数据密集的计算场景提供计算引擎。 DPU将作为CPU的…

LeetCode简单题之棒球比赛

题目 你现在是一场采用特殊赛制棒球比赛的记录员。这场比赛由若干回合组成,过去几回合的得分可能会影响以后几回合的得分。 比赛开始时,记录是空白的。你会得到一个记录操作的字符串列表 ops,其中 ops[i] 是你需要记录的第 i 项操作&#xf…

数据结构:二叉树的非递归遍历

二叉树的前序遍历题目描述思路代码二叉树的中序遍历题目描述思路代码二叉树的后序遍历题目描述思路代码二叉树的层序遍历题目描述前提知识代码二叉树的前序遍历 144.二叉树的前序遍历 题目描述 给你二叉树的根节点 root ,返回它节点值的 前序 遍历。 提示&#…