Sentence-BERT实现文本匹配【回归目标函数】

devtools/2024/9/19 3:43:34/ 标签: bert, 回归, 人工智能

引言

上篇文章我们通过Sentence-Bert提出的分类目标函数来训练句子嵌入模型,本文同样基于Sentence-Bert的架构,但改用回归目标函数。

架构

image-20210923000654664

如上图,计算两个句嵌入 u \pmb u u v \pmb v v​之间的余弦相似度,然后可以使用均方误差(mean-squared-error)作为目标函数。
L = ∣ ∣ y − cosine_sim ( u , v ) ∣ ∣ 2 \mathcal L = ||y - \text{cosine\_sim}(\pmb u,\pmb v)||_2 L=∣∣ycosine_sim(u,v)2
这里的 y y y是真实标签。

回归目标函数的预测不再是整数标签1或0了,而可以为数值。比如对于给定的句子对,可以计算相似度得分。此时推理流程与训练完全相同。

实现

实现采用类似Huggingface的形式,每个文件夹下面有一种模型。分为modelingargumentstrainer等不同的文件。不同的架构放置在不同的文件夹内。

modeling.py:

from dataclasses import dataclassimport torch
from torch import Tensor, nnfrom transformers.file_utils import ModelOutputfrom transformers import (AutoModel,AutoTokenizer,
)import numpy as np
from tqdm.autonotebook import trange
from typing import Optional@dataclass
class BiOutput(ModelOutput):loss: Optional[Tensor] = Nonescores: Optional[Tensor] = Noneclass SentenceBert(nn.Module):def __init__(self,model_name: str,trust_remote_code: bool = True,max_length: int = None,num_classes: int = 2,pooling_mode: str = "mean",normalize_embeddings: bool = False,) -> None:super().__init__()self.model_name = model_nameself.normalize_embeddings = normalize_embeddingsself.device = "cuda" if torch.cuda.is_available() else "cpu"self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=trust_remote_code)self.model = AutoModel.from_pretrained(model_name, trust_remote_code=trust_remote_code).to(self.device)self.max_length = max_lengthself.pooling_mode = pooling_modeself.loss_fct = nn.MSELoss()def sentence_embedding(self, last_hidden_state, attention_mask):if self.pooling_mode == "mean":attention_mask = attention_mask.unsqueeze(-1).float()return torch.sum(last_hidden_state * attention_mask, dim=1) / torch.clamp(attention_mask.sum(1), min=1e-9)else:# clsreturn last_hidden_state[:, 0]def encode(self,sentences: str | list[str],batch_size: int = 64,convert_to_tensor: bool = True,show_progress_bar: bool = False,):if isinstance(sentences, str):sentences = [sentences]all_embeddings = []for start_index in trange(0, len(sentences), batch_size, desc="Batches", disable=not show_progress_bar):batch = sentences[start_index : start_index + batch_size]features = self.tokenizer(batch,padding=True,truncation=True,return_tensors="pt",return_attention_mask=True,max_length=self.max_length,).to(self.device)out_features = self.model(**features, return_dict=True)embeddings = self.sentence_embedding(out_features.last_hidden_state, features["attention_mask"])if not self.training:embeddings = embeddings.detach()if self.normalize_embeddings:embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)if not convert_to_tensor:embeddings = embeddings.cpu()all_embeddings.extend(embeddings)if convert_to_tensor:all_embeddings = torch.stack(all_embeddings)else:all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])return all_embeddingsdef compute_loss(self, scores, labels):labels = torch.tensor(labels).float().to(self.device)return self.loss_fct(scores, labels.view(-1))def forward(self, source, target, labels) -> BiOutput:"""Args:source :target :"""source_embed = self.encode(source)target_embed = self.encode(target)scores = torch.cosine_similarity(source_embed, target_embed)loss = self.compute_loss(scores, labels)return BiOutput(loss, scores)def save_pretrained(self, output_dir: str):state_dict = self.model.state_dict()state_dict = type(state_dict)({k: v.clone().cpu().contiguous() for k, v in state_dict.items()})self.model.save_pretrained(output_dir, state_dict=state_dict)

整个模型的实现放到modeling.py文件中。

arguments.py:

from dataclasses import dataclass, field
from typing import Optionalimport os@dataclass
class ModelArguments:model_name_or_path: str = field(metadata={"help": "Path to pretrained model"})config_name: Optional[str] = field(default=None,metadata={"help": "Pretrained config name or path if not the same as model_name"},)tokenizer_name: Optional[str] = field(default=None,metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"},)@dataclass
class DataArguments:train_data_path: str = field(default=None, metadata={"help": "Path to train corpus"})eval_data_path: str = field(default=None, metadata={"help": "Path to eval corpus"})max_length: int = field(default=512,metadata={"help": "The maximum total input sequence length after tokenization for input text."},)def __post_init__(self):if not os.path.exists(self.train_data_path):raise FileNotFoundError(f"cannot find file: {self.train_data_path}, please set a true path")if not os.path.exists(self.eval_data_path):raise FileNotFoundError(f"cannot find file: {self.eval_data_path}, please set a true path")

定义了模型和数据相关参数。

dataset.py:

from torch.utils.data import Dataset
from datasets import Dataset as dt
import pandas as pdfrom utils import build_dataframe_from_csvclass PairDataset(Dataset):def __init__(self, data_path: str) -> None:df = build_dataframe_from_csv(data_path)self.dataset = dt.from_pandas(df, split="train")self.total_len = len(self.dataset)def __len__(self):return self.total_lendef __getitem__(self, index) -> dict[str, str]:query1 = self.dataset[index]["query1"]query2 = self.dataset[index]["query2"]label = self.dataset[index]["label"]return {"query1": query1, "query2": query2, "label": label}class PairCollator:def __call__(self, features) -> dict[str, list[str]]:queries1 = []queries2 = []labels = []for feature in features:queries1.append(feature["query1"])queries2.append(feature["query2"])labels.append(feature["label"])return {"source": queries1, "target": queries2, "labels": labels}

数据集类考虑了LCQMC数据集的格式,即成对的语句和一个数值标签。类似:

Hello.	Hi.	1
Nice to see you.	Nice	0

trainer.py:

import torch
from transformers.trainer import Trainerfrom typing import Optional
import os
import loggingfrom modeling import SentenceBertTRAINING_ARGS_NAME = "training_args.bin"
logger = logging.getLogger(__name__)class BiTrainer(Trainer):def compute_loss(self, model: SentenceBert, inputs, return_outputs=False):outputs = model(**inputs)loss = outputs.lossreturn (loss, outputs) if return_outputs else lossdef _save(self, output_dir: Optional[str] = None, state_dict=None):# If we are executing this function, we are the process zero, so we don't check for that.output_dir = output_dir if output_dir is not None else self.args.output_diros.makedirs(output_dir, exist_ok=True)logger.info(f"Saving model checkpoint to {output_dir}")self.model.save_pretrained(output_dir)if self.tokenizer is not None:self.tokenizer.save_pretrained(output_dir)# Good practice: save your training arguments together with the trained modeltorch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))

继承🤗 Transformers的Trainer类,重写了compute_loss_save方法。

这样我们就可以利用🤗 Transformers来训练我们的模型了。

utils.py:

import torch
import pandas as pd
from scipy.stats import pearsonr, spearmanr
from typing import Tupledef build_dataframe_from_csv(dataset_csv: str) -> pd.DataFrame:df = pd.read_csv(dataset_csv,sep="\t",header=None,names=["query1", "query2", "label"],)return dfdef compute_spearmanr(x, y):return spearmanr(x, y).correlationdef compute_pearsonr(x, y):return pearsonr(x, y)[0]def find_best_acc_and_threshold(scores, labels, high_score_more_similar: bool):"""Copied from https://github.com/UKPLab/sentence-transformers/tree/master"""assert len(scores) == len(labels)rows = list(zip(scores, labels))rows = sorted(rows, key=lambda x: x[0], reverse=high_score_more_similar)print(rows)max_acc = 0best_threshold = -1# positive examples number so farpositive_so_far = 0# remain negative examplesremaining_negatives = sum(labels == 0)for i in range(len(rows) - 1):score, label = rows[i]if label == 1:positive_so_far += 1else:remaining_negatives -= 1acc = (positive_so_far + remaining_negatives) / len(labels)if acc > max_acc:max_acc = accbest_threshold = (rows[i][0] + rows[i + 1][0]) / 2return max_acc, best_thresholddef metrics(y: torch.Tensor, y_pred: torch.Tensor) -> Tuple[float, float, float, float]:TP = ((y_pred == 1) & (y == 1)).sum().float()  # True PositiveTN = ((y_pred == 0) & (y == 0)).sum().float()  # True NegativeFN = ((y_pred == 0) & (y == 1)).sum().float()  # False NegatvieFP = ((y_pred == 1) & (y == 0)).sum().float()  # False Positivep = TP / (TP + FP).clamp(min=1e-8)  # Precisionr = TP / (TP + FN).clamp(min=1e-8)  # RecallF1 = 2 * r * p / (r + p).clamp(min=1e-8)  # F1 scoreacc = (TP + TN) / (TP + TN + FP + FN).clamp(min=1e-8)  # Accuraryreturn acc, p, r, F1def compute_metrics(predicts, labels):return metrics(labels, predicts)

定义了一些帮助函数,从sentence-transformers库中拷贝了寻找最佳准确率阈值的实现find_best_acc_and_threshold

除了准确率,还计算了句嵌入的余弦相似度与真实标签之间的斯皮尔曼等级相关系数指标。

最后定义训练和测试脚本。

train.py:

from transformers import set_seed, HfArgumentParser, TrainingArgumentsimport logging
from pathlib import Pathfrom datetime import datetimefrom modeling import SentenceBert
from trainer import BiTrainer
from arguments import DataArguments, ModelArguments
from dataset import PairCollator, PairDatasetlogger = logging.getLogger(__name__)
logging.basicConfig(format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",datefmt="%m/%d/%Y %H:%M:%S",level=logging.INFO,
)def main():parser = HfArgumentParser((TrainingArguments, DataArguments, ModelArguments))training_args, data_args, model_args = parser.parse_args_into_dataclasses()# 根据当前时间生成输出目录output_dir = f"{training_args.output_dir}/{model_args.model_name_or_path.replace('/', '-')}-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"training_args.output_dir = output_dirlogger.info(f"Training parameters {training_args}")logger.info(f"Data parameters {data_args}")logger.info(f"Model parameters {model_args}")# 设置随机种子set_seed(training_args.seed)# 加载预训练模型model = SentenceBert(model_args.model_name_or_path,trust_remote_code=True,max_length=data_args.max_length,)tokenizer = model.tokenizer# 构建训练和测试集train_dataset = PairDataset(data_args.train_data_path)eval_dataset = PairDataset(data_args.eval_data_path)# 传入参数trainer = BiTrainer(model=model,args=training_args,train_dataset=train_dataset,eval_dataset=eval_dataset,data_collator=PairCollator(),tokenizer=tokenizer,)Path(training_args.output_dir).mkdir(parents=True, exist_ok=True)# 开始训练trainer.train()trainer.save_model()if __name__ == "__main__":main()

训练

基于train.py定义了train.sh传入相关参数:

timestamp=$(date +%Y%m%d%H%M)
logfile="train_${timestamp}.log"# change CUDA_VISIBLE_DEVICES
CUDA_VISIBLE_DEVICES=3 nohup python train.py \--model_name_or_path=hfl/chinese-macbert-large \--output_dir=output \--train_data_path=data/train.txt \--eval_data_path=data/dev.txt \--num_train_epochs=3 \--save_total_limit=5 \--learning_rate=2e-5 \--weight_decay=0.01 \--warmup_ratio=0.01 \--bf16=True \--eval_strategy=epoch \--save_strategy=epoch \--per_device_train_batch_size=64 \--report_to="none" \--remove_unused_columns=False \--max_length=128 \> "$logfile" 2>&1 &

以上参数根据个人环境修改,这里使用的是哈工大的chinese-macbert-large预训练模型。

注意:

  • --remove_unused_columns是必须的。
  • 通过bf16=True可以加速训练同时不影响效果。
  • 其他参数可以自己调整。
100%|██████████| 18655/18655 [1:17:23<00:00,  4.44it/s]
100%|██████████| 18655/18655 [1:17:23<00:00,  4.02it/s]
09/02/2024 21:02:41 - INFO - trainer - Saving model checkpoint to output/hfl-chinese-macbert-large-2024-09-02_19-45-12
{'eval_loss': 0.09294428676366806, 'eval_runtime': 56.1261, 'eval_samples_per_second': 156.825, 'eval_steps_per_second': 19.617, 'epoch': 5.0}
{'train_runtime': 4643.261, 'train_samples_per_second': 257.11, 'train_steps_per_second': 4.018, 'train_loss': 0.049199433276877584, 'epoch': 5.0}

这里训练了5轮,我们拿最后保存的模型output/hfl-chinese-macbert-large-2024-09-02_19-45-12进行测试。

参数忘改了,为了便于比较,实际上下面的结果是以3轮的训练结果验证的。

测试

test.py: 测试脚本见后文的完整代码。

test.sh:

# change CUDA_VISIBLE_DEVICES
CUDA_VISIBLE_DEVICES=0 python test.py \--model_name_or_path=output/hfl-chinese-macbert-large-2024-09-02_19-45-12/checkpoint-11193 \--test_data_path=data/test.txt

输出:

TestArguments(model_name_or_path='output/hfl-chinese-macbert-large-2024-09-02_19-45-12/checkpoint-11193', test_data_path='data/test.txt', max_length=64, batch_size=128)
Batches: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 98/98 [00:11<00:00,  8.77it/s]
Batches: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 98/98 [00:11<00:00,  8.89it/s]
max_acc: 0.8832, best_threshold: 0.794167
spearman corr: 0.7795 |  pearson_corr corr: 0.7668 | compute time: 22.25s
accuracy=0.883 precision=0.876 recal=0.893 f1 score=0.8843

测试集上的准确率达到88.3%,这种以回归目标函数进行训练的效果没有分类的好。

完整代码

完整代码: →点此←

本文代码是和某次commit相关的,Master分支上的代码随时可能会被优化。

参考

  1. [论文笔记]Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks

http://www.ppmy.cn/devtools/108313.html

相关文章

Ubuntu2204配置连续失败后账户锁定

配置启用pam_faillock sudo nano /etc/pam.d/common-auth在最上面添加以下内容 auth required pam_faillock.so preauth silent audit auth sufficient pam_unix.so nullok try_first_pass auth [defaultdie] pam_faillock.so authfail auditsudo nano /etc/pam.d/…

【苍穹外卖】Day 6 HttpClient、wx小程序

1 HttpClient HttpClient 是 Apache Jakarta Common 下的子项目&#xff0c;可以用来提供高效的、最新的、功能丰富的支持 HTTP 协议的客户端编程工具包&#xff0c;并且它支持 HTTP 协议最新的版本和建议 HttpClient 是一个用于发送 HTTP 请求并接收响应的类或库&#xff0c;在…

Ansible自动化运维入门:从基础到实践的全面指南

Ansible自动化运维入门&#xff1a;从基础到实践的全面指南 随着IT基础设施的日益复杂&#xff0c;自动化运维已经成为提升效率、减少人为错误、优化资源的重要手段。在众多自动化工具中&#xff0c;Ansible因其简洁、易用、无代理&#xff08;Agentless&#xff09;等特性&am…

大数据决策分析平台建设方案(56页PPT)

方案介绍&#xff1a; 大数据决策分析平台旨在通过集成大数据处理、存储、分析与可视化技术&#xff0c;构建一个集数据采集、清洗、整合、分析、预测及决策支持于一体的综合性平台。该平台的目标是帮助企业实现提升决策效率与精准度&#xff0c;优化业务流程与资源配置&#…

easy_spring_boot Java 后端开发框架

Easy SpringBoot 基于 Java 17、SpringBoot 3.3.2 开发的后端框架&#xff0c;集成 MyBits-Plus、SpringDoc、SpringSecurity 等插件&#xff0c;旨在提供一个高效、易用的后端开发环境。该框架通过清晰的目录结构和模块化设计&#xff0c;帮助开发者快速构建和部署后端服务。…

Quartz.Net_依赖注入

简述 有时会遇到需要在IJob实现类中依赖注入其他类或接口的情况&#xff0c;但Quartz的默认JobFactory并不能识别具有有参构造函数的IJob实现类&#xff0c;也就无法进行依赖注入 需要被依赖注入的类&#xff1a; public class TestClass {public TestClass(Type jobType, s…

MQTT broker搭建并用SSL加密

系统为centos&#xff0c;基于emqx搭建broker&#xff0c;流程参考官方。 安装好后&#xff0c;用ssl加密。 进入/etc/emqx/certs,可以看到 分别为 cacert.pem CA 文件cert.pem 服务端证书key.pem 服务端keyclient-cert.pem 客户端证书client-key.pem 客户端key 编辑emqx配…

服务器托管是什么意思?优缺点详解

服务器托管是什么意思&#xff1f;服务器托管是一种服务&#xff0c;其中企业或个人将自己的服务器寄存在第三方数据中心。这种数据中心通常由专业的服务提供商运营&#xff0c;提供必要的物理和网络安全环境&#xff0c;以及稳定的电力和冷却系统。以下是服务器托管的详细介绍…

C#多态,Override和New的用法

一. 面向对象重要特性之多态 要掌握C#的Override和New关键字的用法&#xff0c;首先要理解多态&#xff1b;这里不赘述各种官方对多态的解释&#xff0c;下面给出个人直白理解&#xff1a; 父类F中声明一个方法M并用virtual修饰其为虚方法&#xff0c;子类S实现了相同签名的方…

html+css+js网页设计 故宫7个页面 ui还原度100%

htmlcssjs网页设计 故宫7个页面 ui还原度100% 网页作品代码简单&#xff0c;可使用任意HTML编辑软件&#xff08;如&#xff1a;Dreamweaver、HBuilder、Vscode 、Sublime 、Webstorm、Text 、Notepad 等任意html编辑软件进行运行及修改编辑等操作&#xff09;。 获取源码 1…

P8687 [蓝桥杯 2019 省 A] 糖果

~~~~~ P8687 [蓝桥杯 2019 省 A] 糖果 ~~~~~ 总题单链接 思路 ~~~~~ 发现 k ≤ 20 , m ≤ 20 k\le20,m\le 20 k≤20,m≤20&#xff0c;考虑状压DP。 ~~~~~ 预处理 g [ i ] g[i] g[i] 表示第 i i i 包糖有哪几种糖果。 ~~~~~ 设 d p [ i ] dp[i] dp[i] 表示在 i i i 状态下…

微服务日常总结

1.当我们在开发中&#xff0c;需要连接多个库时&#xff0c;可以在yml中进行配置。 当在查询的时候&#xff0c;跨库时&#xff0c;需要通过DS 注解来指定&#xff0c;需要yml配置需要保持一致。 2. 当我们想把数据存入到clob类型中&#xff0c;需要再字段 的占位符后面加上j…

深度学习基础--卷积的变种

随着卷积同经网络在各种问题中的广泛应用&#xff0c;卷积层也逐渐衍生出了许多变种&#xff0c;比较有代表性的有&#xff1a; 分组卷积( Group Convolution )、转置卷积 (Transposed Convolution) 、空洞卷积( Dilated/Atrous Convolution )、可变形卷积( Deformable Convolu…

MATLAB控制USRP的附加功能安装包

请按照你手里的USRP型号选择对应的MATLAB附加功能包&#xff0c;安好后直接配置即可&#xff0c;这里包含了以下型号的USRP&#xff1a;X300,X310,X410(USRP x-eries)N300,N310, N320,N321, E312,E310,USRP- B200, B200mini, B200mini-i, B205mini-i或B210, SRP- N200, N210或U…

前端框架有哪些

前端框架有哪些 前端框架是用来帮助开发者构建用户界面和交互的库或工具。以下是一些流行的前端框架&#xff1a; React: 由 Facebook 维护的一个声明式、高效且灵活的 JavaScript 库&#xff0c;用于构建用户界面。 Vue.js: 一个渐进式 JavaScript 框架&#xff0c;用于构建…

SpringMVC 第一次复学笔记

服务器启动时&#xff0c;创建spring容器&#xff1b;dispatcherServlet启动时&#xff0c;直接创建springmvc容器初始化一次&#xff0c;实现了springmvc和spring的整合。 SpringMVC里的组件 处理器映射器&#xff08;HandlerMapping&#xff09;负责匹配映射路径对应的Handl…

如何编写Linux PCI设备驱动器 之一

如何编写Linux PCI设备驱动器 之一 PCI寻址PCI驱动器使用的APIpci_register_driver()pci_driver结构pci_device_id结构 如何查找PCI设备存取PCI配置空间读配置空间APIs写配置空间APIswhere的常量值共用部分类型0类型1 PCI总线通过使用比ISA更高的时钟速率来实现更好的性能&…

【STM32】外部中断

当程序正常运行执行main函数&#xff0c;此时如果外部中断来了&#xff0c;执行外部中断函数&#xff0c;实现相应的功能&#xff0c;然后就可以回到main. 一般stm32芯片每个引脚都有自己的外部中断&#xff0c;但是为了限制&#xff0c;会有一个中断线&#xff0c;对应一个中断…

【C++ | 设计模式】代理模式的详解与实现

1. 概念 代理模式&#xff08;Proxy Pattern&#xff09;是一种结构型设计模式&#xff0c;用于控制对对象的访问。它通过引入代理对象&#xff0c;间接地操作目标对象&#xff0c;从而实现对目标对象的控制。代理模式的核心思想是通过代理对象来控制对目标对象的访问。代理对…

flink---window

Window介绍 DataStream: https://nightlies.apache.org/flink/flink-docs-release-1.17/zh/docs/dev/datastream/operators/windows/ SQL: https://nightlies.apache.org/flink/flink-docs-release-1.17/zh/docs/dev/table/sql/queries/window-tvf/ 1、为什么需要Window?…