Bert基础(十九)--Bert实战:文本相似度匹配

devtools/2024/9/23 3:27:51/

文本匹配是指计算机系统识别和确定两段文本之间关系的任务。这个概念非常广泛,涵盖了各种场景,其中文本之间的关系可以是有相似度、问答、对话、推理等。在不同的应用场景下,文本匹配的具体定义可能会有所不同。
以下是几种常见的文本匹配任务及其特点:

  1. 文本相似度计算:计算两个文本之间的相似程度,例如判断两个句子是否表达相同或相似的意思。
  2. 问答匹配:将用户提出的问题与数据库中的答案进行匹配,以提供正确的信息。
  3. 对话匹配:在对话系统中,识别用户输入的文本与系统回复之间的匹配关系,以确保对话的连贯性和准确性。
  4. 文本推理:根据给定的文本内容推断出新的信息或结论。
    此外,像抽取式机器阅读理解和多项选择这样的任务,其本质也是文本匹配。在这些任务中,系统需要理解文本内容,并将其与问题或选项进行匹配,以确定正确答案。
    总之,文本匹配是自然语言处理中的一个重要概念,广泛应用于信息检索、机器翻译、文本生成、对话系统等多个领域。随着技术的发展,文本匹配的算法和模型也在不断进步,以更准确地理解和匹配文本内容。

本次先介绍最简单的文本相似度计算的任务,后面将其他的信息检索、机器翻译、文本生成、对话系统等任务进行实战。

基本步骤:

1 加载数据集
2 数据预处理
3 创建模型
4 创建评估函数
5 创建训练器
6 训练模型
7 评估
8 预测

1 加载数据集

在hugging face没有找到合适的数据集,所以找了一个资源上传在git上,大家可自取

dataset = load_dataset("json", data_files="/kaggle/input/sentence/sentence_pair.json", split="train")
dataset
Dataset({features: ['sentence1', 'sentence2', 'label'],num_rows: 10000
})

看下数据格式

dataset[10]
{'sentence1': '她是一个非常慷慨的女人,拥有自己的一大笔财产。','sentence2': '她有很多钱,但她是个慷慨的女人。','label': '1'}

2 数据预处理

datasets = dataset.train_test_split(test_size=0.2)
datasets
DatasetDict({train: Dataset({features: ['sentence1', 'sentence2', 'label'],num_rows: 8000})test: Dataset({features: ['sentence1', 'sentence2', 'label'],num_rows: 2000})
})

数据格式划分

import torchtokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")def process_function(examples):tokenized_examples = tokenizer(examples["sentence1"], examples["sentence2"], max_length=128, truncation=True)tokenized_examples["labels"] = [int(label) for label in examples["label"]]return tokenized_examplestokenized_datasets = datasets.map(process_function, batched=True, remove_columns=datasets["train"].column_names)
tokenized_datasets

原始数据中label是字符串格式,需要int转换下

print(tokenized_datasets["train"][0])
{'input_ids': [101, 704, 1286, 1762, 100, 4638, 2207, 2238, 828, 2622, 8024, 4692, 4708, 1920, 6496, 1403, 3777, 6804, 6624, 1343, 511, 102, 1762, 100, 7353, 6818, 3766, 3300, 1920, 6496, 511, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'labels': 0}

3 加载模型

from transformers import BertForSequenceClassification 
model = AutoModelForSequenceClassification.from_pretrained("bert-base-chinese")

4 创建评估函数

import evaluateacc_metric = evaluate.load("accuracy")
f1_metirc = evaluate.load("f1")def eval_metric(eval_predict):predictions, labels = eval_predictpredictions = predictions.argmax(axis=-1)acc = acc_metric.compute(predictions=predictions, references=labels)f1 = f1_metirc.compute(predictions=predictions, references=labels)acc.update(f1)return acc

5 创建训练器

train_args = TrainingArguments(output_dir="./cross_model",      # 输出文件夹per_device_train_batch_size=32,  # 训练时的batch_sizeper_device_eval_batch_size=32,  # 验证时的batch_sizelogging_steps=10,                # log 打印的频率evaluation_strategy="epoch",     # 评估策略save_strategy="epoch",           # 保存策略save_total_limit=3,              # 最大保存数learning_rate=2e-5,              # 学习率weight_decay=0.01,               # weight_decaymetric_for_best_model="f1",      # 设定评估指标report_to=['tensorboard'],load_best_model_at_end=True)     # 训练完成后加载最优模型
train_args

6 开始训练

from transformers import DataCollatorWithPadding
trainer = Trainer(model=model, args=train_args, train_dataset=tokenized_datasets["train"], eval_dataset=tokenized_datasets["test"], data_collator=DataCollatorWithPadding(tokenizer=tokenizer),compute_metrics=eval_metric)trainer.train()

在这里插入图片描述

TrainOutput(global_step=750, training_loss=0.1850536205768585, metrics={'train_runtime': 302.9995, 'train_samples_per_second': 79.208, 'train_steps_per_second': 2.475, 'total_flos': 1552684115443200.0, 'train_loss': 0.1850536205768585, 'epoch': 3.0})

7 评估

trainer.evaluate(tokenized_datasets["test"])
{'eval_loss': 0.2481037676334381,'eval_accuracy': 0.912,'eval_f1': 0.8894472361809046,'eval_runtime': 7.6159,'eval_samples_per_second': 262.61,'eval_steps_per_second': 8.272,'epoch': 3.0}

8 预测

pipe = pipeline("text-classification", model=model, tokenizer=tokenizer, device=0)
result = pipe({"text": "昨天我做了个梦", "text_pair": "下雨了"})
result
{'label': '不相似', 'score': 0.989981472492218}

补充

上面是用分类的方式进行训练的,这里有个问题就是如果我们呢需要对一个句子去喝多个句子进行匹配找出最相似的句子就不行了,因为上面这种方式的结果没有对比的价值。所以我们需要修改下,使用MSE来计算损失。
需要修改的地方主要是两个,

数据预处理

def process_function(examples):tokenized_examples = tokenizer(examples["sentence1"], examples["sentence2"], max_length=128, truncation=True)tokenized_examples["labels"] = [float(label) for label in examples["label"]]return tokenized_examples

计算MSE需要计算损失的具体数值,所以要转换成浮点型

评估函数

def eval_metric(eval_predict):predictions, labels = eval_predictpredictions = [int(p > 0.5) for p in predictions]labels = [int(l) for l in labels]# predictions = predictions.argmax(axis=-1)acc = acc_metric.compute(predictions=predictions, references=labels)f1 = f1_metirc.compute(predictions=predictions, references=labels)acc.update(f1)return acc

预测的是一个具体的数值,不是【0,1】,所以计算准确率时要转换下

模型

num_labels=1

num_labels改为一,用回归的形式去预测

完整代码


http://www.ppmy.cn/devtools/25273.html

相关文章

记录k8s以docker方式安装Kuboard v3 过程

原本是想通过在k8s集群中安装kuboad v3的方式安装kuboard,无奈在安装过程中遇到了太多的问题,最后选择了直接采用docker安装的方式,后续有时间会补上直接采用k8s安装kuboard v3的教程。 1.kuboard安装文档地址: 安装 Kuboard v3 …

3D看车有哪些强大的功能?适合哪些企业使用?

3D看车是一种创新的汽车展示方式,它提供了许多强大的功能,特别适合汽车行业的企业使用。 3D看车可实现哪些功能? 1、细节展示: 51建模网提供全套汽车行业3D数字化解决方案,3D看车能够将汽车展示得更加栩栩如生&…

上门服务系统|上门服务小程序搭建流程

随着科技的不断进步和人们生活水平的提高,越来越多的服务开始向线上转型。传统的上门服务业也不例外,随着上门服务小程序的兴起,人们的生活变得更加便捷和高效。本文将为大家介绍上门服务小程序的搭建流程以及应用范围。 一、上门服务小程序搭…

iBarcoder for Mac:一站式条形码生成软件

在数字化时代,条形码的应用越来越广泛。iBarcoder for Mac作为一款专业的条形码生成软件,为用户提供了一站式的解决方案。无论是零售、出版还是物流等行业,iBarcoder都能轻松应对,助力用户实现高效管理。 iBarcoder for Mac v3.14…

windows驱动开发-I/O请求(二)

前面我们已经介绍了I/O请求的数据结构,接下来就是I/O请求的处理了。 IRP的处理 驱动例程收到IRP请求的时候,一般会使用一个巨大的switch语句,来处理每一个IOCTL,虽然前面我们简单的讲述了三种不同的I/O缓冲区方式,但…

【算法基础实验】图论-构建加权无向图

构建加权无向图 理论基础 在图论中,加权无向图是一种每条边都分配了一个权重或成本的图形结构。这种类型的图在许多实际应用中都非常有用,如路由算法、网络流量设计、最小生成树和最短路径问题等。 加权无向图的基本特征 顶点和边: 顶点&…

streampetr原版网络nuscenes数据pkl文件中的各字段含义

streampetr原版网络nuscenes数据pkl文件中的各字段含义 每帧数据都包含下列的信息 "token": 该帧数据的标识,具有唯一性 "prev": 该帧数据上一帧数据的token,如果没有就为"" "next": 该帧数据下一帧数据的toke…

【学习AI-相关路程-工具使用-NVIDIA SDK MANAGER==NVIDIA-jetson刷机工具安装使用 】

【学习AI-相关路程-工具使用-NVIDIA SDK manager-NVIDIA-jetson刷机工具安装使用 】 1、前言2、环境配置3、知识点了解(1)jetson 系列硬件了解(2)以下大致罗列jetson系列1. Jetson Nano2. Jetson TX23. Jetson Xavier NX4. Jetson…