pytorch基于FastText实现词嵌入

ops/2025/2/4 22:42:25/

FastText 是 Facebook AI Research 提出的 改进版 Word2Vec,可以: ✅ 利用 n-grams 处理未登录词
比 Word2Vec 更快、更准确
适用于中文等形态丰富的语言

完整的 PyTorch FastText 代码(基于中文语料),包含:

  • 数据预处理(分词 + n-grams)
  • 模型定义
  • 训练
  • 测试
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import jieba
from collections import Counter
import random# ========== 1. 数据预处理 ==========
corpus = ["我们 喜欢 深度 学习","自然 语言 处理 是 有趣 的","人工智能 改变 了 世界","深度 学习 是 人工智能 的 重要 组成部分"
]# 分词
tokenized_corpus = [list(jieba.cut(sentence)) for sentence in corpus]# 构建 n-grams
def generate_ngrams(words, n=3):ngrams = []for word in words:ngrams += [word[i:i + n] for i in range(len(word) - n + 1)]return ngrams# 生成 n-grams 词表
all_ngrams = set()
for sentence in tokenized_corpus:for word in sentence:all_ngrams.update(generate_ngrams(word))# 构建词汇表
vocab = set(word for sentence in tokenized_corpus for word in sentence) | all_ngrams
word2idx = {word: idx for idx, word in enumerate(vocab)}
idx2word = {idx: word for word, idx in word2idx.items()}# 构建训练数据(CBOW 方式)
window_size = 2
data = []for sentence in tokenized_corpus:indices = [word2idx[word] for word in sentence]for center_idx in range(len(indices)):context = []for offset in range(-window_size, window_size + 1):context_idx = center_idx + offsetif 0 <= context_idx < len(indices) and context_idx != center_idx:context.append(indices[context_idx])if context:data.append((context, indices[center_idx]))  # (上下文, 目标词)# ========== 2. 定义 FastText 模型 ==========
class FastText(nn.Module):def __init__(self, vocab_size, embedding_dim):super(FastText, self).__init__()self.embeddings = nn.Embedding(vocab_size, embedding_dim)self.linear = nn.Linear(embedding_dim, vocab_size)def forward(self, context):context_vec = self.embeddings(context).mean(dim=1)  # 平均上下文向量output = self.linear(context_vec)return output# 初始化模型
embedding_dim = 10
model = FastText(len(vocab), embedding_dim)# ========== 3. 训练 FastText ==========
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)
num_epochs = 100for epoch in range(num_epochs):total_loss = 0random.shuffle(data)for context, target in data:context = torch.tensor([context], dtype=torch.long)target = torch.tensor([target], dtype=torch.long)optimizer.zero_grad()output = model(context)loss = criterion(output, target)loss.backward()optimizer.step()total_loss += loss.item()if (epoch + 1) % 10 == 0:print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {total_loss:.4f}")# ========== 4. 获取词向量 ==========
word_vectors = model.embeddings.weight.data.numpy()# ========== 5. 计算相似度 ==========
def most_similar(word, top_n=3):if word not in word2idx:return "单词不在词汇表中"word_vec = word_vectors[word2idx[word]].reshape(1, -1)similarities = np.dot(word_vectors, word_vec.T).squeeze()similar_idx = similarities.argsort()[::-1][1:top_n + 1]return [(idx2word[idx], similarities[idx]) for idx in similar_idx]# 测试
test_words = ["深度", "学习", "人工智能"]
for word in test_words:print(f"【{word}】的相似单词:", most_similar(word))

1. 生成 n-grams

  • FastText 处理单词的 子词单元(n-grams)
  • 例如 "学习" 会生成 ["学习", "习学", "学"]
  • 这样即使遇到未登录词也能拆分为 n-grams 计算

2. 训练数据

  • 使用 CBOW(上下文预测中心词)
  • 窗口大小 = 2,即:
    句子: ["深度", "学习", "是", "人工智能"]
    示例: (["深度", "是"], "学习")
    

3. FastText 模型

  • 词向量是 n-grams 词向量的平均值
  • 计算公式: 
  • 这样,即使单词没见过,也能用它的 n-grams 计算词向量!

 4. 计算相似度

  • cosine similarity 找出最相似的单词
  • FastText 比 Word2Vec 更准确,因为它能利用 n-grams 捕捉词的语义信息
特性FastTextWord2VecGloVe
原理预测中心词 + n-grams预测中心词或上下文统计词共现信息
未登录词处理可处理无法处理无法处理
训练速度 快
适合领域中文、罕见词传统 NLP大规模数据

http://www.ppmy.cn/ops/155694.html

相关文章

Resnet 改进:尝试在不同位置加入Transform模块

目录 1. TransformerBlock 2. resnet 3. 替换部分卷积层 4. 在特定位置插入Transformer模块 5. 使用Transformer全局特征提取器 6. 其他 Tips:融入模块后的网络经过测试,可以直接使用,设置好输入和输出的图片维度即可 1. TransformerBlock TransformerBlock是Transfo…

【Qt】信号和槽简介

信号与槽是 Qt编程的基础&#xff0c;也是 Qt的一大创新。有了信号与槽的编程机制&#xff0c;在 Qt 中处理界面上各个组件的交互操作就变得比较直观和简单。 信号(signal)是在特定情况下被发射的通知&#xff0c;例如QPushButton较常见的信号就是点击鼠标时发射的 clicked()信…

图像处理之图像灰度化

目录 1 图像灰度化简介 2 图像灰度化处理方法 2.1 均值灰度化 2.2 经典灰度化 2.3 Photoshop灰度化 2.4 C语言代码实现 3 演示Demo 3.1 开发环境 3.2 功能介绍 3.3 下载地址 参考 1 图像灰度化简介 对于24位的RGB图像而言&#xff0c;每个像素用3字节表示&#xff0…

【数据库初阶】表的查询语句和聚合函数

&#x1f389;博主首页&#xff1a; 有趣的中国人 &#x1f389;专栏首页&#xff1a; 数据库初阶 &#x1f389;其它专栏&#xff1a; C初阶 | C进阶 | 初阶数据结构 亲爱的小伙伴们&#xff0c;大家好&#xff01;在这篇文章中&#xff0c;我们将深入浅出地为大家讲解 表的查…

视频拼接,拼接时长版本

目录 视频较长&#xff0c;分辨率较大&#xff0c;这个效果很好&#xff0c;不耗用内存 ffmpeg imageio&#xff0c;适合视频较短 视频较长&#xff0c;分辨率较大&#xff0c;这个效果很好&#xff0c;不耗用内存 ffmpeg import subprocess import glob import os from nats…

ubuntu18.04环境下,Zotero 中pdf translate划线后不翻译问题解决

问题&#xff1a; 如果使用fastgithub&#xff0c;在/etc/profile中设置全局代理&#xff0c;系统重启后会产生划线后不翻译的问题&#xff0c;包括所有翻译代理均不行。终端中取消fastgithub代理&#xff0c;也不行。 解决&#xff1a; 1&#xff09;不在/etc/profile中设置…

GESP2023年9月认证C++六级( 第三部分编程题(1)小杨买饮料)

参考程序&#xff1a; #include <iostream> using namespace std;const int INF 1000000000; // 定义一个大的数&#xff0c;代表无解的状态int cost[2001]; // 用来存储每个容量所需的最小花费int main() {int N 0, L 0;cin >> N >> L; // 读入饮料种…

安卓pad仿写element-ui表单验证

安卓pad仿写element-ui表单验证 背景&#xff1a;仿写对象&#xff1a;showcase&#xff1a; 布局总览&#xff1a;核心代码总览&#xff1a;代码仓&#xff1a; 背景&#xff1a; 最近半年开始接触安卓开发&#xff0c;平时开发接触的点比较零碎&#xff0c;计划闲暇时做一些…