在 PyTorch 中理解词向量,将单词转换为有用的向量表示

ops/2025/2/11 18:20:32/

你要是想构建一个大型语言模型,首先得掌握词向量的概念。幸运的是,这个概念很简单,也是本系列文章的一个完美起点。

那么,假设你有一堆单词,它可以只是一个简单的字符串数组。

animals = ["cat", "dog", "rat", "pig"]

你没法直接用单词进行数学运算,所以必须先把它们转换成数字。最简单的方法就是用它们在数组中的索引值。

animal_to_idx = {animal: idx for idx, animal in enumerate(animals)}

animal_to_idx

# Output:

# {'cat': 0, 'dog': 1, 'rat': 2, 'pig': 3}

当然,等你把数学运算做完,你还需要把索引转换回对应的单词。可以这样做:

idx_to_animal = {idx: animal for animal, idx in animal_to_idx.items()}

idx_to_animal

# Output:

# {0: 'cat', 1: 'dog', 2: 'rat', 3: 'pig'}

用索引来表示单词,在自然语言处理中一般不是个好主意。问题在于,索引会暗示单词之间存在某种顺序关系,而实际上并没有。

比如,我们的数据里,猫和猪之间并没有固有的关系,狗和老鼠之间也没有。但是,使用索引后,看起来猫离猪“很远”,而狗似乎“更接近”老鼠,仅仅因为它们在数组中的位置不同。这些数值上的距离可能会暗示一些实际上并不存在的模式。同样,它们可能会让人误以为这些动物之间存在基于大小或相似度的关系,而这在这里完全没有意义。

一个更好的方法是使用独热编码(one-hot encoding)。独热向量是一个数组,其中只有一个元素是 1(表示“激活”),其他所有元素都是 0。这种表示方式可以完全消除单词之间的错误排序关系。

让我们把单词转换成独热向量:

import numpy as np

n_animals = len(animals)

animal_to_onehot = {}

for idx, animal in enumerate(animals):

one_hot = np.zeros(n_animals, dtype=int)

one_hot[idx] = 1

animal_to_onehot[animal] = one_hot

animal_to_onehot

# Output:

# {

# 'cat': array([1, 0, 0, 0]),

# 'dog': array([0, 1, 0, 0]),

# 'rat': array([0, 0, 1, 0]),

# 'pig': array([0, 0, 0, 1])

# }

可以看到,现在单词之间没有任何隐含的关系了。

独热编码的缺点是,它是一种非常稀疏的表示,只适用于单词数量较少的情况。想象一下,如果你有 10,000 个单词,每个编码都会有 9,999 个零和一个 1,太浪费内存了,存那么多零干嘛……

**是时候创建更密集的向量表示了。换句话说,我们现在要做词向量(word embeddings)**了。

词向量是一种密集向量(dense vector),其中大多数(甚至所有)值都不是零。在机器学习,尤其是自然语言处理和推荐系统中,密集向量可以用来紧凑而有意义地表示单词(或句子、或其他实体)的特征。更重要的是,它们可以捕捉这些特征之间的有意义关系。

举个例子,我们创建一个词向量,其中每个单词用 2 个特征表示,而总共有 4 个单词。

用 PyTorch 创建词向量非常简单。我们只需要使用 nn.Embedding 层。你可以把它想象成一个查找表,其中行代表每个唯一单词,而列代表该单词的特征(即单词的密集向量)。

import torch

import torch.nn as nn

embedding_layer = nn.Embedding(num_embeddings=4, embedding_dim=2)

好,现在我们把单词的索引转换成词向量。这几乎不费吹灰之力,因为我们只需要把索引传给 nn.Embedding 层就行了。

indices = torch.tensor(np.arange(0, len(animals)))

indices

Output:

# tensor([0, 1, 2, 3])

embeddings = embedding_layer(indices)

embeddings

# Output:

# tensor([[ 1.6950, -2.7905],

# [ 2.4086, -0.1779],

# [ 0.7402, 0.0955],

# [-0.5155, 0.0738]], grad_fn=<EmbeddingBackward0>)

现在,我们可以用索引查看每个单词的词向量了。

for animal, _ in animal_to_idx.items():

print(f"{animal}'s embedding is {embeddings[animal_to_idx[animal]]}")

Output:

# cat's embedding is tensor([ 1.6950, -2.7905], grad_fn=<SelectBackward0>)

# dog's embedding is tensor([ 2.4086, -0.1779], grad_fn=<SelectBackward0>)

# rat's embedding is tensor([0.7402, 0.0955], grad_fn=<SelectBackward0>)

# pig's embedding is tensor([-0.5155, 0.0738], grad_fn=<SelectBackward0>)

每个单词都有两个特征——正是我们想要的结果。

目前这些数值没啥实际意义,因为 nn.Embedding 层还没有经过训练。但一旦它被适当地训练了,这些特征就会变得有意义。

注意:

这些特征对模型来说非常关键,但对人类来说可能永远不会“有意义”。它们代表的是通过训练学到的抽象特征。对我们来说,这些特征看起来可能是随机的、毫无意义的,但对一个训练好的模型来说,它们能够捕捉到重要的模式和关系,使其能够有效地理解和处理数据。

在本系列的下一篇文章中,我们将学习如何训练词向量模型。


http://www.ppmy.cn/ops/157578.html

相关文章

蓝桥杯51单片机练习(国信长天比赛用)

文章目录 代码实现头文件固定模板延时函数HC138译码器和或非门流水灯闪烁次数(假设闪烁5次)从左向右依次亮从左向右依次灭 总代码 代码实现 头文件 #include <REGX52.H> 固定模板 void main() { while(1) { } } 延时函数 void Delay(unsigned char t) { while(t–…

MySql --- 作业

一. 触发器 1建立两个表:goods(商品表)、orders(订单表) mysql> create database mydb16_tigger; Query OK, 1 row affected (0.01 sec)mysql> use mydb16_tigger; Database changed mysql> mysql> CREATE TABLE goods (-> gid CHAR(8) PRIMARY KEY,->…

spring cloud和spring boot的区别

Spring Cloud和Spring Boot在Java开发领域中都是非常重要的框架&#xff0c;但它们在目标、用途和实现方式上存在明显的区别。以下是对两者区别的详细解析&#xff1a; 1. 含义与定位 Spring Boot&#xff1a; 是一个快速开发框架&#xff0c;它简化了Spring应用的初始搭建以…

【Spring Boot】Spring 事务探秘:核心机制与应用场景解析

前言 ???本期讲解关于spring 事务介绍~~~ ??感兴趣的小伙伴看一看小编主页&#xff1a;-CSDN博客 ?? 你的点赞就是小编不断更新的最大动力 ??那么废话不多说直接开整吧~~ 目录 ???1.事务 ??1.1什么是事务 ??1.2为什么需要事务 ??1.3操作事务 ???…

【基于SprintBoot+Mybatis+Mysql】电脑商城项目之修改密码和个人资料

&#x1f9f8;安清h&#xff1a;个人主页 &#x1f3a5;个人专栏&#xff1a;【Spring篇】【计算机网络】【Mybatis篇】 &#x1f6a6;作者简介&#xff1a;一个有趣爱睡觉的intp&#xff0c;期待和更多人分享自己所学知识的真诚大学生。 目录 &#x1f383;1.修改密码 -持久…

【Ubuntu】安装和使用Ollama的报错处理集合

Ollama是一个开源的大型语言模型(LLM)推理服务器,为用户提供了灵活、安全和高性能的语言模型推理解决方案。 Ollama的主要特点是它能够运行多种类型的大型语言模型,包括但不限于Alpaca、Llama、Falcon、Mistral等,而无需将模型上传至服务器。这意味着用户可以直接在本地或…

Proxy vs DefineProperty

几年前校招面试的时候被问过一个问题&#xff0c;Vue3/Vue2 如何实现数据和UI的同步&#xff0c;其区别是什么&#xff0c;Vue3的方式优势是什么&#xff1f; 当时背了八股&#xff0c;默写了一通不知所云的代码&#xff0c;面试没过&#xff0c;再也没写过Vue。 今天拿出点时…

PromptSource官方文档翻译

目录 核心概念解析 提示模板&#xff08;Prompt Template&#xff09; P3数据集 安装指南 基础安装&#xff08;仅使用提示&#xff09; 开发环境安装&#xff08;需创建提示&#xff09; API使用详解 基本用法 子数据集处理 批量操作 提示创建流程 Web界面操作 手…