bert实现词嵌入及其参数详解

news/2024/11/8 6:45:03/

实现步骤

  1. 加载BERT预训练好的模型和tokenizer
    如果你已经将bert的预训练模型下载到本地,那么你可以从本地加载
tokenizer = BertTokenizer.from_pretrained('/home/wu/david/demo/bert_base_chinese')
model = BertModel.from_pretrained('/home/wu/david/demo/bert_base_chinese')

如果没有下载,也可以在线加载

tokenizer = BertTokenizer.from_pretrained('bert_base_chinese')
model = BertModel.from_pretrained('bert_base_chinese')
  1. 使用tokenizer.encode_plus将每个文本转换为BERT输入格式
inputs = tokenizer.encode_plus(text,add_special_tokens=True,max_length=512,padding='max_length',truncation=True,return_attention_mask=True,return_tensors='pt',)

解释一下每个参数的意义:

  • text:要编码的文本字符串。
  • add_special_tokens:指示是否在编码中添加特殊令牌(如 [CLS] 和 [SEP])。默认为 True。在BERT中,特殊令牌用于标识句子的开头和结尾,以及句子对任务中的分隔符。
  • max_length:编码后的最大长度。如果文本长度超过此值,将进行截断。默认为 None,表示不进行截断。
  • padding:指示是否对编码进行填充,使其达到 max_length。默认为 False。如果设置为 ‘max_length’,则会将文本填充到 max_length 的长度。填充通常在批处理中使用,确保批处理中的所有样本具有相同的长度。
  • truncation:指示是否对文本进行截断,使其达到 max_length。默认为 False。如果设置为 True,则会将文本截断到 max_length 的长度。
  • return_attention_mask:指示是否返回注意力掩码。默认为 False。如果设置为 True,则会返回一个表示注意力掩码的张量。注意力掩码标识出哪些令牌是真实输入,哪些是填充的,以便模型在处理时忽略填充令牌。
  • return_tensors:指定返回的编码结果的张量类型。可选值为 “tf”(返回 TensorFlow 张量)或 “pt”(返回 PyTorch 张量)。默认为 None,表示返回标准 Python 列表。

inputs返回值有三部分:

  • input_ids:编码后的文本的令牌 ID 列表,可以被模型直接使用。
  • attention_mask:(可选)如果 return_attention_mask=True,则返回一个表示注意力掩码的列表。该列表标识出编码中哪些令牌应该被模型关注,哪些应该被忽略。
  • token_type_ids:(可选)如果 return_token_type_ids=True,则返回一个表示句子类型的列表。在句对任务中,它标识出不同句子的令牌属于哪个句子。
  1. 使用BERT模型进行嵌入
        #调用模型,传入上一步得到的inputsoutputs = model(**inputs)# 获取最后一个隐藏层表示,其形状为 [batch_size, sequence_length, hidden_size]last_hidden_state = outputs.last_hidden_state print(last_hidden_state.shape)token_embeddings = torch.squeeze(last_hidden_state, dim=0) #去掉batch_size这个维度,得到BERT模型编码后的每个令牌的隐藏状态表示

注:

  • last_hidden_state 代表了BERT模型对输入文本的编码表示,其维度为[batch_size, sequence_length, hidden_size],其中,batch_size表示批处理中的样本数量;sequence_length表示每个样本的序列长度,即输入文本的令牌数量;hidden_size表示BERT模型的隐藏状态的维度大小,通常是预训练模型的参数之一。
  • torch.squeeze() 是一个PyTorch函数,用于从张量中移除维度为1的维度,因此,token_embeddings 是经过挤压后的最后一个隐藏状态张量,形状为 [sequence_length, hidden_size],表示BERT模型编码后的每个令牌的隐藏状态表示

完整代码

import torch
import torch.nn as nn
from transformers import BertTokenizer, BertModel
import pandas as pd
from torch.utils.data import Dataset,DataLoader# 本地加载BERT模型和tokenizer
tokenizer = BertTokenizer.from_pretrained('/home/wu/david/demo/bert_base_chinese')
model = BertModel.from_pretrained('/home/wu/david/demo/bert_base_chinese')# 在线加载BERT模型和tokenizer
# tokenizer = BertTokenizer.from_pretrained('bert_base_chinese')
# model = BertModel.from_pretrained('bert_base_chinese')data=["欢迎学习bert!","万丈高楼平地起,成功只能靠自己","希望我们都有一个光明的未来!"]for index,text in enumerate(data):print("对第",index,"个文本进行词嵌入...")# tokenizer.encode_plus负责将每个文本转换为BERT输入格式inputs = tokenizer.encode_plus(text,add_special_tokens=True,max_length=512,padding='max_length',truncation=True,return_attention_mask=True,return_tensors='pt',)with torch.no_grad():# 使用BERT模型进行嵌入outputs = model(**inputs)# 获取最后一个隐藏层表示,其形状为 [batch_size, sequence_length, hidden_size]last_hidden_state = outputs.last_hidden_stateprint(last_hidden_state.shape)token_embeddings = torch.squeeze(last_hidden_state, dim=0)

http://www.ppmy.cn/news/172103.html

相关文章

微信小程序笔记 整理

微信小程序笔记 整理 地址: https://mp.weixin.qq.com/ 小程序的基本构成 pages --> 用来存放所有小程序的页面 页面以文件夹的格式保存在pages里,每个页面由四个基本文件组成 xx.js --> 页面数据 和 逻辑处理的文件xx.json --> 页面的配置 ( 如果和app.…

excel中添加下拉式菜单的方法

excel中添加下拉式菜单的方法 我使用的是excel2016的版本 选择下拉菜单中的内容就可以了

fabric-ca-client颁发Orderer节点证书

创建Orderer节点: function createOrderer {echo echo "Enroll the CA admin" echo mkdir -p organizations/ordererOrganizations/example.comexport FABRIC_CA_CLIENT_HOME=${PWD}/organizations/ordererOrganizations/example.comset -x fabric-ca-client enrol…

invalid sub button url domain hint 解决方法

腾讯官方的例子有问题! invalid sub button url domain hint 无效的子菜单域名错误!设置菜单的url的时候,习惯性的设置例如www.qq.com不应该这样写,而应该写全 https://www.qq.com/ 注意的是一定要包涵http或者https&#xff0c…

k8s pv,pvc无法删除问题

k8s pv,pvc无法删除问题 一般删除步骤为:先删pod再删pvc最后删pv 但是遇到pv始终处于“Terminating”状态,而且delete不掉。如下图: 解决方法: 直接删除k8s中的记录: 1 kubectl patch pv xxx -p {"metadata&…

error pulling image configuration: read tcp xxx.xxx.x.xxx:xx->xxx.xx.xxx.xx:xxx: read: connection

问题描述: 当我们使用docker pull拉取镜像的时候,会报error pulling image configuration: read tcp xxx.xxx.x.xxx:xx->xxx.xx.xxx.xx:xxx: read: connection类似这样的错误。 问题分析: 由于国内网络问题,无法连接到 dock…

ORA-15099: disk ‘/dev/asm_data‘ is larger than maximum size of 2097152 MBs

数据库版本 12c(12.1.0.2)ASM 12c以后支持超过2T 的盘记录ORA-15099 问题处理过程: ORA-15099:disk /dev/asm_data is larger than maximum size of 2097152 MBs查看磁盘的组的rdbms兼容版本 [gridl01testdb01 ~]$ asmcmd lsattr -G DATA -l Name …

固件错误Possible missing firmware解决办法

问题: W: Possible missing firmware /lib/firmware/rtl_nic/rtl8125a-3.fw for module r8169 解决方法: 1,进入如下这个地址,固件文件非常全面,找到适合自己的版本, https://git.kernel.org/pub/scm/linux…