pytorch实现基于Word2Vec的词嵌入

ops/2025/2/7 10:25:49/

PyTorch 实现 Word2Vec(Skip-gram 模型) 的完整代码,使用 中文语料 进行训练,包括数据预处理、模型定义、训练和测试


1. 主要特点

支持中文数据,基于 jieba 进行分词
使用 Skip-gram 进行训练,适用于小数据集
支持负采样,提升训练效率
使用 cosine similarity 计算相似单词

完整代码:

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
import jieba
from collections import Counter
from sklearn.metrics.pairwise import cosine_similarity# ========== 1. 数据预处理 ==========
corpus = ["我们 喜欢 深度 学习","自然 语言 处理 是 有趣 的","人工智能 改变 了 世界","深度 学习 是 人工智能 的 重要 组成部分"
]# 超参数
window_size = 2      # 窗口大小
embedding_dim = 10   # 词向量维度
num_epochs = 100     # 训练轮数
learning_rate = 0.01 # 学习率
batch_size = 4       # 批大小
neg_samples = 5      # 负采样个数# 分词 & 构建词汇表
tokenized_corpus = [list(jieba.cut(sentence)) for sentence in corpus]
vocab = set(word for sentence in tokenized_corpus for word in sentence)
word2idx = {word: idx for idx, word in enumerate(vocab)}
idx2word = {idx: word for word, idx in word2idx.items()}# 统计词频
word_counts = Counter([word for sentence in tokenized_corpus for word in sentence])
total_words = sum(word_counts.values())# 计算负采样概率
word_freqs = {word: count / total_words for word, count in word_counts.items()}
word_powers = {word: freq ** 0.75 for word, freq in word_freqs.items()}
Z = sum(word_powers.values())
word_distribution = {word: prob / Z for word, prob in word_powers.items()}# 负采样函数
def negative_sampling(positive_word, num_samples=5):words = list(word_distribution.keys())probabilities = list(word_distribution.values())negatives = []while len(negatives) < num_samples:neg = np.random.choice(words, p=probabilities)if neg != positive_word:negatives.append(neg)return negatives# 生成 Skip-gram 训练数据
data = []
for sentence in tokenized_corpus:indices = [word2idx[word] for word in sentence]for center_idx in range(len(indices)):center_word = indices[center_idx]for offset in range(-window_size, window_size + 1):context_idx = center_idx + offsetif 0 <= context_idx < len(indices) and context_idx != center_idx:context_word = indices[context_idx]data.append((center_word, context_word))# 转换为 PyTorch 张量
data = [(torch.tensor(center), torch.tensor(context)) for center, context in data]# ========== 2. 定义 Word2Vec (Skip-gram) 模型 ==========
class Word2Vec(nn.Module):def __init__(self, vocab_size, embedding_dim):super(Word2Vec, self).__init__()self.embedding = nn.Embedding(vocab_size, embedding_dim)self.output_layer = nn.Linear(embedding_dim, vocab_size)def forward(self, center_word):embed = self.embedding(center_word)  # 获取中心词向量out = self.output_layer(embed)       # 计算词分布return out# 初始化模型
model = Word2Vec(len(vocab), embedding_dim)# ========== 3. 训练 Word2Vec ==========
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)for epoch in range(num_epochs):total_loss = 0random.shuffle(data)  # 每轮打乱数据for center_word, context_word in data:optimizer.zero_grad()output = model(center_word.unsqueeze(0))  # 预测词分布loss = criterion(output, context_word.unsqueeze(0))  # 计算损失loss.backward()optimizer.step()total_loss += loss.item()if (epoch + 1) % 10 == 0:print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss:.4f}")# ========== 4. 测试词向量 ==========
word_vectors = model.embedding.weight.data.numpy()# 计算单词相似度
def most_similar(word, top_n=3):if word not in word2idx:return "单词不在词汇表中"word_vec = word_vectors[word2idx[word]].reshape(1, -1)similarities = cosine_similarity(word_vec, word_vectors)[0]# 获取相似度最高的 top_n 个单词(排除自身)similar_idx = similarities.argsort()[::-1][1:top_n+1]return [(idx2word[idx], similarities[idx]) for idx in similar_idx]# 测试相似词
test_words = ["深度", "学习", "人工智能"]
for word in test_words:print(f"【{word}】的相似单词:", most_similar(word))

数据预处理
  • 使用 jieba.cut() 进行分词
  • 创建 word2idxidx2word
  • 使用滑动窗口生成 (中心词, 上下文词) 训练样本
  • 实现 negative_sampling() 提高训练效率
模型
  • Embedding 学习词向量
  • Linear 计算单词的概率分布
  • CrossEntropyLoss 计算目标词与预测词的匹配度
  • 使用 Adam 进行梯度更新
计算词相似度
  • 使用 cosine_similarity 计算词向量相似度
  • 找出 top_n 个最相似的单词

 5. 可优化点

 使用更大的中文语料库(如 THUCNews
 使用 t-SNE 进行词向量可视化
增加负采样,提升模型训练效率


http://www.ppmy.cn/ops/156413.html

相关文章

自定义多功能输入对话框:基于 Qt 打造灵活交互界面

一、引言 在使用 Qt 进行应用程序开发时&#xff0c;我们经常需要与用户进行交互&#xff0c;获取他们输入的各种信息。QInputDialog 是 Qt 提供的一个便捷工具&#xff0c;可用于简单的输入场景&#xff0c;但当需求变得复杂&#xff0c;需要支持更多类型的输入控件&#xff0…

使用 SurrealDB 构建高效的 GraphQL 后端

1. SurrealDB 简介 SurrealDB 是一款新兴的分布式多模型数据库&#xff0c;它结合了关系型数据库&#xff08;SQL 的强大查询能力&#xff09;与 NoSQL 数据库的灵活性&#xff0c;支持图数据库的复杂关系查询&#xff0c;同时具有内置的实时订阅功能。相比传统数据库&#xf…

linux的基础入门2

linux的root用户 无论是Windows、MacOS、Linux均采用多用户的管理模式进行权限管理。 在Linux系统中,拥有最大权限的账户名为:root(超级管理员) 而在前期&#xff0c;我们一直使用的账户是普通的用户 普通用户的权限&#xff0c;一般在其HOME目录内是不受限的 一旦出了HOME目录…

问卷数据分析|SPSS之分类变量描述性统计

1.点击分析--描述统计--频率 2. 选中分类变量&#xff0c;点击中间箭头 3.图表选中条形图&#xff0c;图表值选择百分比&#xff0c;选择确定 4.这里显示出了描述性统计的结果 5.下面就是图形&#xff0c;但SPSS画的图形都不是很好啊看&#xff0c;建议用其他软件画图&#xff…

Mac 部署Ollama + OpenWebUI完全指南

文章目录 &#x1f4bb; 环境说明&#x1f6e0;️ Ollama安装配置1. 安装[Ollama](https://github.com/ollama/ollama)2. 启动Ollama3. 模型存储位置4. 配置 Ollama &#x1f310; OpenWebUI部署1. 安装Docker2. 部署[OpenWebUI](https://www.openwebui.com/)&#xff08;可视化…

【大数据技术】搭建完全分布式高可用大数据集群(Scala+Spark)

搭建完全分布式高可用大数据集群(Scala+Spark) scala-2.13.16.tgzspark-3.5.4-bin-without-hadoop.tgz注:请在阅读本篇文章前,将以上资源下载下来。 写在前面 本文主要介绍搭建完全分布式高可用集群Spark的详细步骤。 注意: 统一约定将软件安装包存放于虚拟机的/softwa…

Java实战经验分享

1. 项目优化与性能提升 面试问题&#xff1a; 聊聊你印象最深刻的项目&#xff0c;或者做了哪些优化 你在项目中如何解决缓存穿透问题&#xff1f; 缓存穿透是我们做缓存优化时最常遇到的问题&#xff0c;特别是当查询的对象在数据库中不存在时&#xff0c;缓存层和数据库都会…

linux 使用docker安装 postgres 教程,踩坑实践

踩坑实践,安装好了不能远程访问。 防火墙已关闭、postgres 配置了允许所有ip 访问、网络是通的。端口也是开放的&#xff0c;就是不能用数据库链接工具访问。 最后发现是云服务器端口没开 ,将其打开 到这一步完全正确了&#xff0c;但是又报错了 关于连接PostgreSQL时提示 FA…