Pytorch从0复现worc2vec skipgram模型及fasttext训练维基百科语料词向量演示

server/2024/11/13 2:56:31/

目录

Skipgram架构

代码开源声明

Pytorch复现Skip-gram

导包及随机种子设置

维基百科数据读取

建立词频元组列表并根据词频排序

建立词频字典,word_id字典,id_word字典

二次采样

正采样与负采样

Skipgram模型类

模型训练

词向量输出

近义词寻找

fasttext训练Skip-gram


Skipgram架构

初始论文中理论实现中,训练了两个参数矩阵,Word2vec中可以拆解为为词向量的降维矩阵和升维矩阵,初始使用独热编码对token进行序列标注,有图可以看出,由3*5的参数矩阵左乘5*1的词向量可以得到3*1的降维后的词向量,然后再由5*3的参数矩阵对降维后的词向量进行升维,与要预测的token进行损失计算

在实际实现中会采用隐式独热编码,也就是并不会手动通过独热编码进行词向量索引,比如语料库总共有5个token,对其进行独热编码后,由3维的独热编码来表示5个token,以下演示通过独热编码索引出词向量矩阵中对应token的词向量

import numpy as npnp.random.seed(0)y = np.array([1,0,0,0,0])
x = np.random.randn(5,3)
print(y)
print(x)
print(np.dot(y,x))
# [1 0 0 0 0]
# [[ 1.76405235  0.40015721  0.97873798]
#  [ 2.2408932   1.86755799 -0.97727788]
#  [ 0.95008842 -0.15135721 -0.10321885]
#  [ 0.4105985   0.14404357  1.45427351]
#  [ 0.76103773  0.12167502  0.44386323]]
# [1.76405235 0.40015721 0.97873798]

初始独热编码为1 0 0 0 0,通过左乘词向量矩阵可以索引到词向量矩阵的第一行,也就是一个token的词向量,

Word2vec一般分为Cbow以及Skip-gram,Skip-gram主要通过中间的token预测两侧的token,Skip-gram则是通过两侧的token预测中间的token.在理论实现中,例如Skip-gram就是通过取出中间 token的降维后的词向量再对其通过升维矩阵进行向量升维,与两侧token的原始独热编码进行损失计算

本文将进行Skip-gram的pytorch复现

在实际编码实现与理论实现具有一些区别,首先,实际编码实现中并不会显示创建独热编码进行词向量索引,而是直接通过embedding层来实现词向量矩阵的初始化和训练.

在理论实现上的损失计算是通过升维矩阵来进行与中间token的独热编码的损失计算

在实际实现上则有所不同

1.Skip-gram是通过中间预测两侧的结果,在实际是通过降维后的中间token的词向量,然后使用另一个降维矩阵对两侧token进行降维运算,最后通过降维后的中间token词向量和降维后的两侧token的词向量进行点乘计算用于计算相似度

2.对于token间的相似度,在实际实现中采用滑块的方式,我们会按序在语料中选择中间token(center_token),然后通过设置滑动窗口来进行两侧词的获取,

3.实际实现中还进行了负采样,也就是在中心token与相邻token进行点乘计算时,通常中心词与相邻token具有较高相似度,也就是点乘结果会越大,而与较远的token的相似度较低,点乘的结果也就会越小,在第2点中,提到的滑块就是用于选取相邻token的实现方式

3.在实际实现中可以选择性实现二次采样,用于随机删除高频词,因为高频词可能会对低频词的词向量学习产生影响

代码开源声明

本文包含的所有代码,数据集及训练完成的模型权重都可在下方的github链接中找到,如有需要使用训练好的模型权重及完整代码,可通过下方链接下载:

GitHub - Foxbabe1q/Pytorch_skipgram: Use pytorch to define skipgram model to train with wikipedia corpus. And I also use fasttext's skipgram to train the corpus

Pytorch复现Skip-gram

导包及随机种子设置

import io
import os
import sys
import requests
import math
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import pandas as pdnp.random.seed(42)
torch.manual_seed(42)
device = torch.device("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")

维基百科数据读取

def load_data():with open('fil9','r') as f:data = f.read()print(data[:100])corpus = data.split()print(corpus[:100])return corpusif __name__ == '__main__':corpus = load_data()#  anarchism originated as a term of abuse first used against early working class radicals including t
# ['anarchism', 'originated', 'as', 'a', 'term', 'of', 'abuse', 'first', 'used', 'against', 'early', 'working', 'class', 'radicals', 'including', 'the', 'diggers', 'of', 'the', 'english', 'revolution', 'and', 'the', 'sans', 'culottes', 'of', 'the', 'french', 'revolution', 'whilst', 'the', 'term', 'is', 'still', 'used', 'in', 'a', 'pejorative', 'way', 'to', 'describe', 'any', 'act', 'that', 'used', 'violent', 'means', 'to', 'destroy', 'the', 'organization', 'of', 'society', 'it', 'has', 'also', 'been', 'taken', 'up', 'as', 'a', 'positive', 'label', 'by', 'self', 'defined', 'anarchists', 'the', 'word', 'anarchism', 'is', 'derived', 'from', 'the', 'greek', 'without', 'archons', 'ruler', 'chief', 'king', 'anarchism', 'as', 'a', 'political', 'philosophy', 'is', 'the', 'belief', 'that', 'rulers', 'are', 'unnecessary', 'and', 'should', 'be', 'abolished', 'although', 'there', 'are', 'differing']

建立词频元组列表并根据词频排序

def build_word_freq_tuple(corpus):word_freq_dict = {}for word in corpus:if word in word_freq_dict:word_freq_dict[word] += 1elif word not in word_freq_dict:word_freq_dict[word] = 1word_freq_tuple = sorted(word_freq_dict.items(), key=lambda x: x[1], reverse=True)print(word_freq_tuple[:10])return word_freq_tupleif __name__ == '__main__':corpus = load_data()word_freq_tuple = build_word_freq_tuple(corpus)# [('the', 7446708), ('of', 4453926), ('one', 3776770), ('zero', 3085174), ('and', 2916968), ('in', 2480552), ('two', 2339802), ('a', 2241744), ('nine', 2063649), ('to', 2028129)]

建立词频字典,word_id字典,id_word字典

def convert_corpus_id(corpus, word_id_dict):id_corpus = []for word in corpus:id_corpus.append(word_id_dict[word])print('corpus_size: ', len(id_corpus))print(id_corpus[:20])return id_corpusif __name__ == '__main__':corpus = load_data()word_freq_dict, word_id_dict, id_word_dict = build_word_id_dict(corpus)id_corpus = convert_corpus_id(corpus, word_id_dict)# vocabulary size:  833184
# corpus_size:  124301826
# [9558, 3423, 19, 7, 277, 1, 3451, 56, 82, 208, 174, 781, 500, 9838, 187, 0, 28373, 1, 0, 179]

这里可以看到语料总长度达到了1亿多词数,但是这个数量级的语料仍然较少,之后介绍的二次采样可以酌情选择是否选择,在语料较为不足的时候,二次采样可能产生相反效果

二次采样

二次采样用于通过删除一定数量的高频词来更好地训练低频词的词向量,公式如下

P(w_i) = \max \left(1 - \sqrt{\frac{t}{f(w_i)}}, 0\right)

这里的f(w_i)指的是词频除总词数,t是一个阈值,通常为1e-5,t设置的越大,被删除的概率越小

P(w_i)为被删除的概率

def subsampling(corpus, word_freq_dict):corpus = [word for word in corpus if not np.random.rand() < (1 - (np.sqrt(1e-5 * len(corpus) / word_freq_dict[word])))]print('corpus_size after subsampling: ', len(corpus))return corpusif __name__ == '__main__':corpus = load_data()word_freq_dict, word_id_dict, id_word_dict = build_word_id_dict(corpus)corpus = subsampling(corpus, word_freq_dict)# corpus_size:  124301826
# vocabulary size:  833184
# corpus_size after subsampling:  83240619

正采样与负采样

def build_negative_sampling_dataset(corpus, word_id_dict, id_word_dict, negative_sample_size = 10, max_window_size = 3):dataset = []for center_word_idx, center_word in enumerate(corpus):window_size = np.random.randint(1, max_window_size+1)positive_range = (max(0, center_word_idx - window_size), min(len(corpus) - 1, center_word_idx + window_size))positive_samples = [corpus[word_idx] for word_idx in range(positive_range[0], positive_range[1]+1) if word_idx != center_word_idx]for positive_sample in positive_samples:dataset.append((center_word, positive_sample, 1))sample_idx_list = np.arange(len(word_id_dict))j = corpus[positive_range[0]: positive_range[1]+1]sample_idx_list = np.delete(sample_idx_list, j)negative_samples = np.random.choice(sample_idx_list, size=negative_sample_size, replace=False)for negative_sample in negative_samples:dataset.append((center_word, negative_sample, 0))print('20 samples of the dataset')for i in range(20):print('center_word:', id_word_dict[dataset[i][0]], 'target_word:', id_word_dict[dataset[i][1]], 'label',dataset[i][2])return datasetif __name__ == '__main__':corpus = load_data()word_freq_dict, word_id_dict, id_word_dict = build_word_id_dict(corpus)corpus = subsampling(corpus, word_freq_dict)corpus = convert_corpus_id(corpus, word_id_dict)dataset = build_negative_sampling_dataset(corpus, word_id_dict, id_word_dict, negative_sample_size = 10)# 20 samples of the dataset
# center_word: originated target_word: working label 1
# center_word: originated target_word: class label 1
# center_word: originated target_word: gulfs label 0
# center_word: originated target_word: propenents label 0
# center_word: originated target_word: pelletier label 0
# center_word: originated target_word: exclaiming label 0
# center_word: originated target_word: bod label 0
# center_word: originated target_word: liturgical label 0
# center_word: originated target_word: quattro label 0
# center_word: originated target_word: anatolius label 0
# center_word: originated target_word: interstratified label 0
# center_word: originated target_word: das label 0
# center_word: working target_word: originated label 1
# center_word: working target_word: class label 1
# center_word: working target_word: radicals label 1
# center_word: working target_word: clip label 0
# center_word: working target_word: moulting label 0
# center_word: working target_word: gnomon label 0
# center_word: working target_word: neural label 0
# center_word: working target_word: marsupial label 0

这里正采样选择中心词周围至多6个词作为与中心词语义强相关的词,而在其它词中随机挑选10个词用于负采样,强相关label为1,负相关label为0

Skipgram模型类

class SkipGram(nn.Module):def __init__(self, vocab_size, embedding_size):super(SkipGram, self).__init__()self.vocab_size = vocab_sizeself.embedding_size = embedding_sizeself.embedding = nn.Embedding(self.vocab_size, self.embedding_size)self.out_embedding = nn.Embedding(self.vocab_size, self.embedding_size)init_range = (1 / embedding_size) ** 0.5nn.init.uniform_(self.embedding.weight, -init_range, init_range)nn.init.uniform_(self.out_embedding.weight, -init_range, init_range)def forward(self, center_idx, target_idx, label):center_embedding = self.embedding(center_idx)target_embedding = self.embedding(target_idx)sim = torch.mul(center_embedding, target_embedding)sim = torch.sum(sim, dim=1, keepdim=False)loss = F.binary_cross_entropy_with_logits(sim, label,reduction='sum')return loss

这里使用第一个embedding矩阵作为最后的词向量矩阵,并且训练相关性使用词向量点乘值作为指标 

模型训练

def train(vocab_size, dataset):my_skipgram = SkipGram(vocab_size = vocab_size, embedding_size=300)my_skipgram.to(device)my_dataset = create_dataset(dataset)my_dataloader = DataLoader(my_dataset, batch_size=64, shuffle=True)optimizer = optim.Adam(my_skipgram.parameters(), lr=0.001)epochs = 10loss_list = []start_time = time.time()for epoch in range(epochs):total_loss = 0total_sample = 0for center_idx, target_idx, label in my_dataloader:loss = my_skipgram(center_idx, target_idx, label)optimizer.zero_grad()loss.backward()optimizer.step()total_loss += loss.item()total_sample += len(center_idx)print(f'epoch: {epoch+1}, loss = {total_loss/total_sample}, time = {time.time() - start_time : .2f}')loss_list.append(total_loss/total_sample)plt.plot(np.arange(1, epochs + 1),loss_list)plt.title('Loss_curve')plt.xlabel('Epoch')plt.ylabel('Loss')plt.xticks(np.arange(1, epochs + 1))plt.savefig('loss_curve.png')plt.show()torch.save(my_skipgram.state_dict(), 'skip_gram.pt')if __name__ == '__main__':corpus = load_data()word_freq_dict, word_id_dict, id_word_dict = build_word_id_dict(corpus)corpus = subsampling(corpus, word_freq_dict)corpus = convert_corpus_id(corpus, word_id_dict)dataset = build_negative_sampling_dataset(corpus, word_id_dict, id_word_dict, negative_sample_size = 10)train(len(word_id_dict), dataset)

这里训练只训练了10个epoch,并且为了节约训练资源,由于原语料长度超过一亿,所以这里只选取长度为200万的语料进行训练

词向量输出

def predict(word, vocab_size, word_id_dict):if word not in word_id_dict:print(f"Word '{word}' not found in the vocabulary.")return Nonemy_skipgram = SkipGram(vocab_size = vocab_size, embedding_size=300)my_skipgram.load_state_dict(torch.load('skip_gram.pt'))my_skipgram.to(device)my_skipgram.eval()word_id = torch.tensor(word_id_dict[word], device=device, dtype=torch.int64)print(f"Predicting the embedding vector for word '{word}':\n{my_skipgram.embedding(word_id)}")if __name__ == '__main__':corpus = load_data()word_freq_dict, word_id_dict, id_word_dict = build_word_id_dict(corpus)corpus = subsampling(corpus, word_freq_dict)corpus = convert_corpus_id(corpus, word_id_dict)dataset = build_negative_sampling_dataset(corpus, word_id_dict, id_word_dict, negative_sample_size = 10)train(len(word_id_dict), dataset)predict('sport', len(word_id_dict), word_id_dict)

近义词寻找

def similarity(word, vocab_size, word_id_dict, id_word_dict, neighbors = 5):if word not in word_id_dict:print(f"Word '{word}' not found in the vocabulary.")return Nonemy_skipgram = SkipGram(vocab_size=vocab_size, embedding_size=300)my_skipgram.load_state_dict(torch.load('skip_gram.pt', weights_only=True))my_skipgram.to(device)my_skipgram.eval()word_id = torch.tensor(word_id_dict[word], device=device, dtype=torch.int64)word_embedding = my_skipgram.embedding(word_id)similarity_score = {}for idx in word_id_dict.values():other_word_embedding = my_skipgram.embedding(torch.tensor(idx, device=device, dtype=torch.int64))sim = torch.matmul(word_embedding, other_word_embedding)/(torch.norm(word_embedding, dim=0, keepdim=False) * torch.norm(other_word_embedding, dim=0, keepdim=False))similarity_score[id_word_dict[idx]] = sim.item()nearest_neighbors = sorted(similarity_score.items(), key=lambda x: x[1], reverse=True)[:5]print(nearest_neighbors)return nearest_neighborsif __name__ == '__main__':corpus = load_data()word_freq_dict, word_id_dict, id_word_dict = build_word_id_dict(corpus)corpus = subsampling(corpus, word_freq_dict)corpus = convert_corpus_id(corpus, word_id_dict)dataset = build_negative_sampling_dataset(corpus, word_id_dict, id_word_dict, negative_sample_size = 10)train(len(word_id_dict), dataset)predict('sport', len(word_id_dict), word_id_dict)similarity('sport', len(word_id_dict), word_id_dict, id_word_dict, neighbors = 5)

这里查找近义词,会从词典中找到点乘值最大的5个词,个数可以通过修改neighbors更改

fasttext训练Skip-gram

fasttext训练的过程较为简单,该模型,包括还有CBOW都被集成在了模块中

import fasttextdef train():skipgram = fasttext.train_unsupervised('fil9', model = 'skipgram')skipgram.save_model('skipgram.bin')def skg_test1():skipgram = fasttext.load_model('skipgram.bin')print(skipgram.get_word_vector('sport'))print(skipgram.get_nearest_neighbors('sport'))if __name__ == '__main__':train()skg_test1()# Read 124M words
# Number of words:  218316
# Number of labels: 0
# Progress: 100.0% words/sec/thread:   38918 lr:  0.000000 avg.loss:  1.071778 ETA:   0h 0m 0s
# Warning : `load_model` does not return WordVectorModel or SupervisedModel any more, but a `FastText` object which is very similar.
# [-1.1217905e-01 -2.1082790e-01 -5.0111616e-05 -7.6881155e-02
#  -2.0150667e-01 -1.8065287e-01  1.3297442e-01  1.3444095e-02
#  -1.5131533e-01 -2.5561339e-01  1.5086566e-01 -8.5557923e-02
#  -2.1246003e-01 -8.0699474e-02 -1.5511900e-01 -2.4630783e-01
#   4.1686368e-01  8.0300289e-01  2.5104052e-01 -7.7809072e-01
#   2.2462079e-01  8.2177565e-02  1.7808667e-01 -3.3937061e-01
#   1.2025767e-01  9.7873092e-02 -3.8934144e-01  1.2671056e-01
#  -2.7373591e-01  4.1039872e-01 -2.9629371e-01  4.4961619e-01
#   5.0581735e-02 -1.9909970e-01  1.0461334e-01 -4.9297757e-02
#  -9.5666438e-02  1.6832566e-01  7.4807540e-02  6.5610033e-01
#  -2.6710102e-01  2.5174522e-01  2.0871958e-01 -2.3539853e-01
#  -1.0441781e-01 -3.5934374e-01 -2.0167212e-01 -6.7970419e-01
#  -4.6956554e-02  9.3441598e-02  3.8153380e-01  2.0482899e-01
#   6.1529225e-01 -9.8463172e-01 -5.7401802e-02 -1.5414989e-01
#   6.7769766e-02  2.2661546e-01 -3.1193841e-02  3.8101819e-01
#  -3.1099179e-01 -2.9264178e-02  2.0313324e-01 -3.6542088e-01
#  -1.2520532e-01  1.8720575e-01 -2.6330149e-01  1.9312735e-01
#  -5.1107663e-01 -2.5122452e-01  2.2448047e-01 -4.7734442e-01
#   2.5731093e-01 -1.4026532e-01  4.3919176e-02 -2.0015708e-01
#  -2.8174376e-01  3.3095101e-01  1.0486527e-01  2.8560793e-01
#  -2.4086323e-01 -9.3831137e-02 -1.9629408e-01  2.4319877e-01
#  -1.8636097e-01 -3.9179447e-01  7.6361425e-02  1.6013722e-01
#  -9.0249017e-02 -5.6596959e-01  4.8584041e-01  3.4663376e-01
#   2.6066643e-01 -7.1866415e-03  1.7896013e-01 -1.2109153e+00
#  -7.9120353e-02  7.6195911e-02  4.5524022e-01 -1.4492531e-01]
# [(0.849130392074585, 'sports'), (0.8167348504066467, 'sporting'), (0.8091928362846375, 'competitions'), (0.7699509859085083, 'racing'), (0.7655908465385437, 'sportsman'), (0.7654882073402405, 'bobsledding'), (0.7621665000915527, 'bobsleigh'), (0.7620510458946228, 'motorsport'), (0.7576955556869507, 'korfball'), (0.7561532258987427, 'competiting')]


http://www.ppmy.cn/server/140811.html

相关文章

就是这个样的粗爆,手搓一个计算器:完美平方三项式计算器

作为程序员&#xff0c;没有合适的工具&#xff0c;就得手搓一个&#xff0c;PC端&#xff0c;移动端均可适用。废话不多说&#xff0c;直接上代码。 HTML: <div class"calculator"><div class"input-group"><label for"termA"…

electron 中 contextBridge 作用

1. 安全地实现渲染进程和主进程之间的通信 在 Electron 应用中&#xff0c;主进程和渲染进程是相互隔离的&#xff0c;这是为了安全和稳定性考虑。 然而&#xff0c;在很多情况下&#xff0c;渲染进程需要访问主进程中的某些功能&#xff0c;例如系统级别的操作或者一些应用级…

Oracle RAC的thread

参考文档&#xff1a; Real Application Clusters Administration and Deployment Guide 3 Administering Database Instances and Cluster Databases Initialization Parameter Use in Oracle RAC Table 3-3 Initialization Parameters Specific to Oracle RAC THREAD Sp…

陪诊问诊APP开发实战:基于互联网医院系统源码的搭建详解

时下&#xff0c;开发一款功能全面、用户体验良好的陪诊问诊APP成为了医疗行业的一大热点。本文将结合互联网医院系统源码&#xff0c;详细解析陪诊问诊APP的开发过程&#xff0c;为开发者提供实用的开发方案与技术指导。 一、陪诊问诊APP的背景与功能需求 陪诊问诊APP核心目…

i2c-tools 4.3 for Android 9.0

i2c-tools 4.3 Android 9.0下编译i2c-tools 4.3 下载源码 cd external wget https://mirrors.edge.kernel.org/pub/software/utils/i2c-tools/i2c-tools-4.3.tar.gz tar -zxvf i2c-tools-4.3.tar.gz 增加Android.mk Android.mk LOCAL_PATH : $(call my-dir)include $(C…

C#封装EPPlus库为Excel导出工具

1&#xff0c;添加NUGet包 2&#xff0c;封装工具类 using OfficeOpenXml; using System; using System.Collections.Generic; using System.IO; using System.Linq; using System.Reflection;namespace GMWPF.utils {public class ExcelUtil<T>{/// <summary>///…

huggingface 下载方法 测试ok

目录 设置环境变量 python 下载无token&#xff1a; python 下载加token 常见报错 设置环境变量 Linux export HF_ENDPOINThttps://hf-mirror.comWindows Powershell $env:HF_ENDPOINT "https://hf-mirror.com" pip install -U huggingface_hub huggingfac…

Rust闭包(能够捕获周围作用域变量的匿名函数,广泛应用于迭代、过滤和映射)闭包变量三种捕获方式:通过引用(不可变引用)、通过可变引用和通过值(取得所有权)

文章目录 Rust 闭包详解闭包的定义与语法基本语法 闭包的特性- 环境捕获&#xff08;三种捕获方式&#xff1a;通过引用、通过可变引用和通过值&#xff08;取得所有权&#xff09;&#xff09;示例代码 - 内存安全与生命周期示例代码1 示例代码2&#xff1a;闭包所有权转移示例…