实现步骤
- 加载BERT预训练好的模型和tokenizer
如果你已经将bert的预训练模型下载到本地,那么你可以从本地加载
tokenizer = BertTokenizer.from_pretrained('/home/wu/david/demo/bert_base_chinese')
model = BertModel.from_pretrained('/home/wu/david/demo/bert_base_chinese')
如果没有下载,也可以在线加载
tokenizer = BertTokenizer.from_pretrained('bert_base_chinese')
model = BertModel.from_pretrained('bert_base_chinese')
- 使用tokenizer.encode_plus将每个文本转换为BERT输入格式
inputs = tokenizer.encode_plus(text,add_special_tokens=True,max_length=512,padding='max_length',truncation=True,return_attention_mask=True,return_tensors='pt',)
解释一下每个参数的意义:
- text:要编码的文本字符串。
- add_special_tokens:指示是否在编码中添加特殊令牌(如 [CLS] 和 [SEP])。默认为 True。在BERT中,特殊令牌用于标识句子的开头和结尾,以及句子对任务中的分隔符。
- max_length:编码后的最大长度。如果文本长度超过此值,将进行截断。默认为 None,表示不进行截断。
- padding:指示是否对编码进行填充,使其达到 max_length。默认为 False。如果设置为 ‘max_length’,则会将文本填充到 max_length 的长度。填充通常在批处理中使用,确保批处理中的所有样本具有相同的长度。
- truncation:指示是否对文本进行截断,使其达到 max_length。默认为 False。如果设置为 True,则会将文本截断到 max_length 的长度。
- return_attention_mask:指示是否返回注意力掩码。默认为 False。如果设置为 True,则会返回一个表示注意力掩码的张量。注意力掩码标识出哪些令牌是真实输入,哪些是填充的,以便模型在处理时忽略填充令牌。
- return_tensors:指定返回的编码结果的张量类型。可选值为 “tf”(返回 TensorFlow 张量)或 “pt”(返回 PyTorch 张量)。默认为 None,表示返回标准 Python 列表。
inputs返回值有三部分:
- input_ids:编码后的文本的令牌 ID 列表,可以被模型直接使用。
- attention_mask:(可选)如果 return_attention_mask=True,则返回一个表示注意力掩码的列表。该列表标识出编码中哪些令牌应该被模型关注,哪些应该被忽略。
- token_type_ids:(可选)如果 return_token_type_ids=True,则返回一个表示句子类型的列表。在句对任务中,它标识出不同句子的令牌属于哪个句子。
- 使用BERT模型进行嵌入
#调用模型,传入上一步得到的inputsoutputs = model(**inputs)# 获取最后一个隐藏层表示,其形状为 [batch_size, sequence_length, hidden_size]last_hidden_state = outputs.last_hidden_state print(last_hidden_state.shape)token_embeddings = torch.squeeze(last_hidden_state, dim=0) #去掉batch_size这个维度,得到BERT模型编码后的每个令牌的隐藏状态表示
注:
- last_hidden_state 代表了BERT模型对输入文本的编码表示,其维度为[batch_size, sequence_length, hidden_size],其中,batch_size表示批处理中的样本数量;sequence_length表示每个样本的序列长度,即输入文本的令牌数量;hidden_size表示BERT模型的隐藏状态的维度大小,通常是预训练模型的参数之一。
- torch.squeeze() 是一个PyTorch函数,用于从张量中移除维度为1的维度,因此,token_embeddings 是经过挤压后的最后一个隐藏状态张量,形状为 [sequence_length, hidden_size],表示BERT模型编码后的每个令牌的隐藏状态表示
完整代码
import torch
import torch.nn as nn
from transformers import BertTokenizer, BertModel
import pandas as pd
from torch.utils.data import Dataset,DataLoader# 本地加载BERT模型和tokenizer
tokenizer = BertTokenizer.from_pretrained('/home/wu/david/demo/bert_base_chinese')
model = BertModel.from_pretrained('/home/wu/david/demo/bert_base_chinese')# 在线加载BERT模型和tokenizer
# tokenizer = BertTokenizer.from_pretrained('bert_base_chinese')
# model = BertModel.from_pretrained('bert_base_chinese')data=["欢迎学习bert!","万丈高楼平地起,成功只能靠自己","希望我们都有一个光明的未来!"]for index,text in enumerate(data):print("对第",index,"个文本进行词嵌入...")# tokenizer.encode_plus负责将每个文本转换为BERT输入格式inputs = tokenizer.encode_plus(text,add_special_tokens=True,max_length=512,padding='max_length',truncation=True,return_attention_mask=True,return_tensors='pt',)with torch.no_grad():# 使用BERT模型进行嵌入outputs = model(**inputs)# 获取最后一个隐藏层表示,其形状为 [batch_size, sequence_length, hidden_size]last_hidden_state = outputs.last_hidden_stateprint(last_hidden_state.shape)token_embeddings = torch.squeeze(last_hidden_state, dim=0)