基于NLP的医学搜索相关性判断

embedded/2025/1/8 19:54:46/

摘  要:

本文介绍了一个基于自然语言处理技术的医学搜索相关性判断实验。本项目将该任务转为自然语言处理中常见的多分类任务,采用“通用预训练模型”,“领域预训练模型”,“调参调优”,“模型融合”的四阶段范式,选择使用了优势互补的BERT-wwm-ext、ERNIE、Roberta_large三个通用预训练模型;针对二元句子查询和医疗背景,分别选择了Roberta_large_pair和Ernie-health-zh的领域专用预训练模型。针对该赛题中训练样本数据较少且不平衡这一难点,在经过数据扩充和数据交叉后,使用了对抗性权重扰动(AWP)对抗训练、随机权重衰减(SWA)、设置分层学习率、Multi-Sample Dropout、K折交叉验证等技术增强模型的鲁棒性,并通过融合不同优势的多个模型以及两阶段的伪标签学习,来对最终结果进行预测。

最终,本项目训练的模型在阿里天池竞赛中获得了Accuracy 0.887的成绩,在总共22979个参赛队伍中排名第1。该实验结果表明,自然语言处理技术在医学相关搜索相关性判断方面具有良好的应用潜力,对相关搜索的语义分析与归类具有积极意义。

 

  1. 项目介绍
    1. 赛题描述

Query(即搜索词)之间的相关性是评估两个Query所表述主题的匹配程度,即判断Query-A和Query-B是否发生转义,以及转义的程度也是自然语言处理中的的重要论题也是各个搜索引擎中需要计算的问题。Query的主题是指Query的专注点,判定两个查询词之间的相关性是一项重要的任务,常用于长尾Query的搜索质量优化场景,本任务数据集就是在这样的背景下产生的。数据集来源于中文医疗信息评测基准CBLUE[1],由阿里巴巴夸克医疗事业部提供。

    1. 赛题说明

日常生活中,如果我们在搜索引擎中输入一次查询,如图 1,搜索引擎常常会自动推荐很多相似的搜索查询推荐。此时即对于一个你输入的 Title,搜索引擎自动为你推荐了一些相关的 Query供你进行选择,用户可以直接点击搜索引擎推荐的查询链接,从而提高用户搜索体验,也能提高搜索引擎的准确性。本赛题中我们需要完成的任务即与这个过程相关。

图 1 阿里夸克App搜索示意图,搜索query为“小孩子打呼噜”

       在本赛题中,我们将Query和Title的相关度共分为3档(0-2),0分为相关性最差,2分表示相关性最好。

  1. 2分:表示A与B等价,表述完全一致
  2. 1分:B为A的语义子集,B指代范围小于A
  3. 0分:B为A的语义父集,B指代范围大于A;或者A与B语义毫无关联

2分:Query-A和Query-B主题等价,表述一致,例如表 1;

表 1 2分医学检索词相关性示例

Query-A

Query-B

解释

小孩子打呼噜是什么原因引起的

小孩子打呼噜什么原因

双眼皮怎么遗传

双眼皮遗传规律

黄体

女性黄体

点痣

点痣祛痣

1分:B为A的语义子集,B指代范围小于A,例如表 2

表 2 1分医学检索词相关性示例

Query-A

Query-B

解释

双眼皮遗传规律

内双眼皮遗传

尿酸高手脚酸痛

尿酸高 脚疼

海绵状血管瘤

多发性海绵状血管瘤

足藓

足藓如何治疗

室管膜囊肿与蛛网膜

左侧极蛛网膜囊肿

室管膜囊肿与蛛网膜为不同部位,但B为蛛网膜的一部分,子集。

板蓝根

好医生板蓝根

怀孕血糖高对胎儿有什么影响

怀孕初期血糖高对胎儿的影响

什么感冒药效果好

什么感冒药起效快

效果好有程度及快慢之分,故B为子集

搭桥手术和支架的区别

什么是支架和搭桥

A表述:搭桥的概念,支架的概念,二者的区别。B表述:搭桥的概念,支架的概念。B表述了A的一部分,故子集。

0分:第一种情况: B为A的语义父集,B指代范围大于A,例如表 3;

表 3 0分医学检索词B为A的语义父集相关性示例

Query-A

Query-B

解释

双眼皮遗传规律

单眼皮与双眼皮遗传

小孩子打呼噜什么原因

孩子打呼噜

牛蒡可以煮着吃

牛蒡如何吃

死了的大闸蟹能吃吗

死了的螃蟹能吃么

拔智齿后悔死了

拔了智齿

B是A必要前提条件,没有B的发生就没有A

光子嫩肤后注意事项

光子嫩肤的注意事项

白血病血常规有啥异常

白血病血检有哪些异常

第二种情况:A与B语义毫无关联,例如表 4;

表 4 0分医学检索词无关联相关性示例

Query-A

Query-B

解释

脑梗最怕的四种食物

脑梗患者吃什么好

牛蒡可以煮着吃

牛蒡有副作用吗

天津肿瘤医院官网

北京市肿瘤医院官网

二甲双胍副作用是啥

二甲双胍有副作用吗

麻腮风两针间隔时间

麻腮风接种时间

排卵前出血

排卵前期出血

A为排卵未开始,B为排卵刚开始时

石榴上火么

吃番石榴上火吗

石榴与番石榴为不同水果,易混淆

本次赛题的任务为预测两个query的相关性,相关性等级共有三档,因此该赛题应为三分类问题。又由于该题数据为两个query,且涉及到对语义的理解和语义关系的判断,因此这也是一个自然语言处理(NLP)问题。对于NLP问题,一般有两种常规处理步骤,即传统的机器学习方法或采用深度学习的方法。考虑到由于传统机器学习方法一般不能很好的理解文字的语义,这类浅层机器学习在NLP问题上很难有比较好的效果,而深度学习可以对语义进行理解。故本文主要采用后者深度学习的方法进行解决。

    1. 评测指标

本次赛题共有开放训练集数据15000条,验证集数据1600条,测试集数据1596条。本任务的评价指标使用准确率来进行评估,若用Accuracy表示准确率,用Ncorrect

表示预测正确的Query数,用Ntotal

表示所有Query数,则有:

Accuracy=NcorrectNtotal

    1. 本文结构

本文一共分为6章。第一章为项目的简介,介绍了本赛题的背景,内容,示例与评测指标。第二章为数据分析与处理,主要介绍了本次赛题中数据的分析与可视化结果,并给出了数据预处理和数据清洗的方法。第三章为模型搭建,详细介绍了本项目中采用的模型及其结构,并介绍了本项目使用的调参手段。第四章为实验结果,介绍了经过多次尝试后最终的线上排名与成绩。第五章为总结与展望,对本项目做了详细的总结,并指出后续还有哪些可以提高的空间。最后,第六章为参考文献。

  1. 数据分析与处理

数据的分析和处理对模型训练以及最终结果起着至关重要的影响。在本节,我们主要介绍了对数据的可视化分析,数据的清洗以及构建数据集入模。

    1. 数据分析

首先,将训练数据读入并查看,可以看到如图 2的结果。训练数据一共有一万五千行,每行包括四列,分别为id,需要判断的两个查询以及一个标签。并且,对于每个数据来说,连续的数据的query1是相同的,即存在多个query围绕同一话题的情况。这是一个比较重要的特征,如果在后续进行验证集验证过程中不进行分块选取,则会导致验证结果受影响,模型欠拟合的情况。

图 2 训练数据示意图

       由于存在两个需要查询的字符串,接着我们对两个字符串的长度进行分析,结果如图 3所示。绝大部分查询的长度都位于5~15个字符之间,并且query1和query2的长度以及分布都大致相同。

图 3 query长度分布图

       接着,我们对标签进行计数分析,作柱状图如图 4。可以看到,标签的分布很不均匀,超过60%的数据的标签都为0,数据存在标签不平衡现象。

图 4 标签分布图

最后,由于本次任务为NLP类任务,数据集比较干净,查看数据集中的空数据以及非法数据如图 5。可以发现,本数据集中仅有一条数据为脏数据(label为非法值),将其移除即可。

图 5 非法数据筛选

       再查看其他数据,具体如图 6,可以发现其他所有数据都是干净的,合法的,不存在非法数据的情况,因此数据分析可以到此结束。

图 6 其他数据的非法数据分析

    1. 数据增强处理

考虑本次比赛中的数据量相对较少,并且存在标签分布不均衡的问题,可以考虑进行数据增强。一般来说,数据增强分为内部增强和外部增强。

由于对于一个相同的query1,存在多个不同的query2,并且等价的语义存在传递性,则若有querya=queryb

, querya=queryc

,我们可以得到queryb=queryc

。对于label为1的情况,语义小于关系是偏序关系,此时若存在querya<queryb

, queryb<queryc

则可扩充querya<queryc

。此外,可以考虑将等价语义与小于语义进行组合扩充。同理,语义不同也可以进行扩充,具体如图 7。

图 7 内部数据增强关系示意图

考虑到0,1,2标签实际分布的比例,数据的扩充不应逆转原数据集中的比例,而又要弥补原数据集中数据不够,标签分布不均衡的问题。因此,需要对生成的数据按照比例进行采样。考虑到正负样本均衡,根据生成的类别数量,最终将标签0、1、2增强比例分别设置为如表 5。

表 5 数据增强标签比例表

标签

比例

0

0.15

1

0.2

2

0.3

    1. 文本入模预处理

我们的输入数据主要是文本,无法直接输入模型,所以在输入模型之前需要对文本内容进行数据预处理,主要是分词和编码的操作,获得编码信息后,再与数据的标签label组合在一起,构建新的数据集。不妨以Bert[11]来展示一个合法输入的过程,具体如图 8所示。

图 8 Bert的输入示例

首先,我们需要将数据集的格式处理为[CLS]Sentence 1[SEP]Sentence 2[SEP]的格式。可以通过载入预训练模型的tokenizer包来进行分词和编码,同时要确保拼接的大小不要超过512(Bert处理最大长度为512),这里我们在数据探索阶段统计过query1和query2的文本长度,两者最大长度相加不会超过,因此可以不做拆分处理。另外提一下如果在其他情况下如像段落等较长的文本时,则需要对其进行处理,如按句子拆分等。比如doc由[ sentence1, sentence2, sentence3 ]组成,则我们需要处理为:

  1. query [SEP] sentence1
  2. query [SEP] sentence2
  3. query [SEP] sentence3

在本次比赛中,我们使用tokenizer.encode_plus()方法来进行编码,tokenizer.encode_plus() 的编码方式比tokenizer.encode()在文本分类上的编码方式要好,在中文分类数据集上会有1个点左右的差别,符合我们的需求。tokenizer.encode_plus()会返回所有的编码信息:

  1. input_ids:单词在词典中的编码
  2. token_type_ids:区分两个句子的编码(上句全为0,下句全为1)
  3. attention_mask:指定对哪些词进行self-Attention操作

获得编码信息后,便可以和数据标签组合在一起,完成了特征的转换,具体的实现过程如图 9所示。

图 9 文本预处理与特征转换过程实现

训练模型一般要先处理数据输入问题和预处理问题,pytorch提供了几个有用的工具:torch.utils.data.Dataset类和torch.utils.data.DataLoader类。

主要流程是先把原始数据转变成torch.utils.data.Dataset类,随后再把得到的 torch.utils.data.Dataset类当作一个参数传递给 torch.utils.data.DataLoader类,得到一个数据加载器,这个数据加载器每次可以返回一个 Batch 的数据供模型训练使用,大大方便了我们进行模型的训练。接着我们要将特征转换获得的feature转变成Dataset类,首先通过继承来定义属于我们的数据集,然后实现必要的三个方法(__init__ 方法, __getitem__ 方法和 __len__ 方法),具体实现如图 10所示

图 10 BuildDataSet类

接着将我们前面处理得到的数据feature转换为Dataset类,传入DataLoader中获取数据加载器,shuffle参数设置为True,每次训练打乱训练集,从而防止过拟合,完整的实现如图 11所示。

图 11 数据的加载器

  1. 模型搭建
    1. 整体思路

经过第一章和第二章的分析,我认为本赛题任务的难点主要有如下四点:

  1. 训练集数据量比较小,模型容易过拟合;
  2. 训练集标签分布不均衡,尽管通过内部数据增强扩充了数据,但数据的标签仍分布不平衡,而模型预测更倾向于样本数量多的标签类别,这可能会导致预测准确度下降;
  3. 线下指标并不能和线上指标达成一致,故只能通过线上指标来挑选模型,然而线上提交次数比较有限;
  4. 数据集中的数据存在比较明显的聚集现象,即存在多个query围绕同一title的现象,这对本地划分验证集有这更高的要求,需要均匀的划分验证集,否则验证效果会比较差。

本项目模型搭建的整体思路如下,考虑到本任务为自然语言处理中常见的多分类任务,我采用“通用预训练模型->领域预训练模型->调参调优->模型融合”的四阶段范式,主要选择使用了优势互补的BERT-wwm-ext[2]、ERNIE[3]、Roberta_large[2]三个通用预训练模型,针对二元句子查询和医疗背景,分别选择了Roberta_large_pair[3]和Ernie-health-zh[5]的领域专用预训练模型。针对该赛题中训练样本数据较少且不平衡这一难点,我使用了对抗性权重扰动(AWP)[6]对抗训练、随机权重衰减(SWA)[7]、设置分层学习率、Multi-Sample Dropout、K折交叉验证等技术增强模型的鲁棒性,并通过集成融合不同优势的多个模型以及两阶段的伪标签学习,来对最终结果进行预测,最终取得线上Accuracy 0.887的成绩,在22979个参赛选手中排名第1,位次比0.004%。整体流程如图 12所示:

图 12 整体流程图

具体来说,我们先分别使用10折BERT-wwm-ext模型,单折ERNIE模型,单折Roberta_large模型,5折Roberta_large_pair模型和单折Ernie-health-zh模型训练出5个单模型。然后经过加权平均将模型进行一次融合并对测试数据进行标注,标注完成后筛选置信度高的数据作为伪标签数据与原数据混合,使用这个数据在Roberta_large_pair模型上使用5折交叉验证进行训练。得到结果后再进行一次伪标签的筛选并与原数据进行融合,使用融合后的数据在Ernie-health-zh模型上使用5折交叉验证进行训练,得到最终结果。

本报告后续内容将按照先后顺序对上图中的内容依次进行介绍。

    1. 通用预训练模型

预训练是自然语言处理赛题中比较常见的方法。通过使用预训练模型,可以大幅度提升模型的准确度,并降低学习的负担。接下来将首先简要介绍“预训练”,接着将介绍本项目中使用的3个通用预训练模型。

      1. 预训练

在实际生产中,我们经常会遇到“标注资源稀缺而无标注资源丰富”的现实,某种特殊的任务只存在非常少量的相关训练数据,以至于模型不能从中学习总结到有用的规律。例如,如果我想对一批法律领域的文件进行关系抽取,则需要投入大量的精力在法律领域的文件中进行关系抽取的标注,然后将标注好的数据投入模型进行训练。但是即使是我标注了几百万条这样的数据(实际情况中,在一个领域内标注几百万条几乎不可能,因为成本非常高),和动辄上亿的无标注语料比起来,还是显得过于单薄。“预训练”这时便可以派上用场。

对于一系列的实际问题,一般认为数据之间存在共性和特性。而预训练思想将训练任务拆解成共性学习和特性学习两个步骤。“预训练”的做法是将大量低成本收集的训练数据放在一起,经过某种预训方法去学习其中的共性,然后将其中的共性“移植”到特定任务的模型中,再使用相关特定领域的少量标注数据进行“微调”,这样的话,模型只需要从“共性”出发,去学习该特定任务的“特性”部分即可。例如,如图 13,考虑A,一个完全不懂英文的人去做英文法律文书的关键词提取的工作会完全无法进行,或者说需要非常多的时间去学习。但是如果让一个英语为母语但是没接触过此类工作的人(我们称他为B)去做这项任务,B可能只需要相对比较短的时间学习就可以上手这项任务。在这里,英文知识就属于“共性”的知识,这类知识不必要只通过英文法律文书的相关语料进行学习,而是可以通过大量其他英文语料来进行学习。

图 13 一个预训练的例子

因此,可以将预训练类比成学习任务分解:在上面这个例子中,如果我们直接让A去学习这样的任务,这就对应了传统的直接训练方法。如果我们先让A变成B,再让B去学习同样的任务,那么就对应了“预训练+微调”的思路。很显然,从“英文法律文书”出发的学习速度大于从“英文”出发的学习速度, 更大于从0出发的学习速度。其实采用“预训练”思路的B和C,不仅学习速度高于A,更重要的是,他们的学习效果往往好于A。

      1. BERT-wwm-ext

Whole Word Masking (wwm),即为全词mask或整词mask,是谷歌在2019年5月31日发布的一项BERT的升级版本,主要更改了原预训练阶段的训练样本生成策略。简单来说,原有基于WordPiece的分词方式会把一个完整的词切分成若干个子词,在生成训练样本时,这些被分开的子词会随机被mask。而在全词mask中,如果一个完整的词的部分WordPiece子词被mask,则同属该词的其他部分也会被mask。

BERT-wwm-ext[2]是由BERT-base-Chinese改进而来,由于谷歌官方发布的BERT-base-Chinese中,中文是以字为粒度进行切分,没有考虑到传统NLP中的中文分词(CWS)。 我们将全词mask的方法应用在了中文中,使用了中文维基百科(包括简体和繁体)进行训练,并且使用了哈工大LTP[8]作为分词工具,即对组成同一个词的汉字全部进行mask。

考虑到在本项目中使用了多个模型,为了方便训练减少代码重复,在代码实现过程中我封装了train函数,以方便不同模型的调用和训练过程。Train函数的入口为已经设置好的config类,以及各个数据集和tokenizer的路径。所有数据会先通过数据预处理分词成对应可以入模的数据后再进行训练。具体的分词过程如上一小节所示。所有训练的入口函数都为train函数。完成数据处理之后会调用model_train函数正式进入模型的训练。其中,train函数如图 14所示。

图 14 train函数

接着,定义优化器和损失函数,我们所使用的预训练模型将损失计算也封装了进去,所以不需要在这里进行损失函数的定义和损失计算,这个只是指预训练的损失计算。具体如图 15所示。

图 15 优化器与损失函数

接着,我们定义一些必要的变量及学习率超参。学习率是神经网络优化时的重要超参数,可使用warmup预热学习率策略来针对learning_rate的提升。由于刚开始训练的时候,模型的权重是随机初始化的,此时若选择一个较大的学习率,可能会导致模型不稳定(振荡),选择warmup策略可以使开始训练的几个epoch或者一些step内学习率较小,从而使模型趋于稳定,待模型稳定后再选择预设的学习率进行训练。具体如图 16所示,我们设计一个learning rate scheduler,在训练过程中更新learning rate。

图 16 学习率warmup策略与一些必要变量

接着,进行如图 17的训练过程。将模型切换为训练模式,每次训练前DataLoader数据加载器可以直接解决数据输入和预处理问题,此外,我们设置了每次返回的数据打乱,防止过拟合。将数据输入模型,获取模型输出outputs和计算的loss,再将model梯度清零,并更新learning rate schedule。在训练过程中,每20个batch我们评估一下模型,并保存当前最优的模型。

图 17 训练模型

对于模型中需要进行评估来确定模型准确度的需求,我们设定了model_evaluate函数。只需要在model(dev)之前加上model.eval()切换为测试模式,并把BN和DropOut固定住,只使用训练好的值即可。具体如图 18所示。

图 18 模型的评估

至此,我们已经实现了模型的训练过程以及评估过程,为了方便训练,我们会将模型持久化存储于磁盘上,此时还需补充加载模型以及保存模型的工具函数即可,具体如图 19所示。

图 19 模型的保存和加载的工具函数

至此,即可以比较方便的进行模型的训练。后续每个模型的训练过程都类似,都通过调用train函数完成,后面不做赘述。接着,使用Pytorch构建模型,并在扩充完的数据集上训练,如图 20和图 21所示。

图 20 使用Pytorch构建Bert模型

图 21 训练Bert模型

       将得到的单模型进行线上提交,得到如图 22的结果。

图 22 Bert单模型线上分数

      1. ERNIE

Google提出的BERT模型,通过随机屏蔽15%的字或者word,利用Transformer的多层self-attention双向建模能力,在各项nlp下游任务中(如句子对分类任务, 单句分类任务, 查询回答任务) 都取得了很好的成绩。但是,BERT 模型主要是聚焦在针对字或者英文word粒度的完形填空学习上面,没有充分利用训练数据当中词法结构,语法结构,以及语义信息去学习建模。比如“我要买苹果手机”,BERT模型将“我”,“要”,“买”,“苹”,“果”,“手”,“机”每个字都统一对待,随机mask,丢失了“苹果手机”是一个特定的专有名词这一信息,这是词法信息的缺失。同时我+买+名词是一个非常明显的购物意图的句式,BERT没有对此类语法结构进行专门的建模,如果预训练的语料中只有“我要买苹果手机”,“我要买华为手机”,哪一天出现了一个新的手机牌子比如栗子手机,而这个手机牌子在预训练的语料当中并不存在,没有基于词法结构以及句法结构的建模,对于这种新出来的词是很难给出一个很好的向量表示的。如图 23所示,而ERNIE[3]通过对训练数据中的词法结构,语法结构,语义信息进行统一建模,极大地增强了通用语义表示能力,在多项任务中均取得了大幅度超越BERT的效果。

图 23 Bert与ERNIE的Transform过程mask对比图

此外,在医疗数据中,往往会存在较多的实体概念;此外文本相似度作为问答任务的子任务,数据描述类型也偏向于口语。ERNIE通过对词、实体等语义单元的掩码,使模型学习完整概念的语义表示,其训练语料包括了百科类文章、新闻资讯、论坛对话。因此ERNIE能够更准确表达语句中实体的语义,且符合口语的情景。

接着,使用Pytorch构建模型,并在扩充完的数据集上训练,如图 24和图 25所示。

图 24 使用Pytorch构建ERNIE模型

图 25 训练ERNIE模型

将得到的单模型进行线上提交,得到如图 26的结果。

图 26 ERNIE单模型线上分数

      1. Roberta_large

与BERT相比,RoBERTa基本没有什么太大创新,主要是在BERT基础上做了几点调整:训练时间更长,batch size更大,训练数据更多,移除了next predict loss,训练序列更长,动态调整Masking机制。由于Roberta采用了更大的数据进行训练,并且采用了更合理的微调策略,因此Roberta_large[2]是目前大多数NLP任务的“State Of The Art”模型。在Roberta_large中文版本使用了动态掩码、全词掩码,增加了训练数据,并改变了生成的方式和语言模型的任务。因此,我们认为在医疗文本上,Roberta_large能更好地对文本进行编码,故选择Robert作为最后一个通用模型。

接着,使用Pytorch构建模型,并在扩充完的数据集上训练,如图 27和图 28所示

图 27 使用Pytorch构建Roberta_large模型

图 28 训练Roberta_large模型

将得到的单模型进行线上提交,得到如图 29的结果。

图 29 Roberta_large单模型线上分数

    1. 领域预训练模型

在上一小节中介绍了通用预训练模型,即对于所有中文的预训练模型。考虑到本赛题的情景是在医疗下的二元对话,故分别引入二元对话预训练模型和医学医疗下的预训练模型,具体见下。

      1. Roberta_large_pair

考虑到本赛题的任务为对一个句子Pair进行相关性检索,为了提升对句子对二元关系的专用领域特征学习准确度,特选了针对二元关系的领域预训练模型。Roberta_large_pair [3]是针对文本对任务提出的专门模型,能够较好地处理语义相似度或句子对问题。因此,在二元句子文本相似度任务上,往往能够取得更好的结果。

接着,使用Pytorch构建模型,并在扩充完的数据集上训练,本节中的代码与上一节中的基本相同,此处不做赘述,仅展示如图 30所示的训练过程和如图 31所示的线上提交结果。

图 30 训练Roberta_large_pair模型

图 31 Roberta_large_pair单模型线上分数

      1. Ernie-health-zh

此外,考虑到本赛题是在医疗背景下的专业性领域赛题,需要一定程度上对医疗专有名词的理解,在此引入了Ernie-health-zh[5]数据集。Ernie-health-zh是使用与ERNIE相同的结构在大量的医疗数据上训练得到的模型,对医疗领域的数据有比较高的准确度。

接着,使用Pytorch构建模型,并在扩充完的数据集上训练,本节中的代码与上一节中的基本相同,此处不做赘述,仅展示如图 32所示的训练过程和如图 33所示的线上提交结果。

图 32 训练Ernie-health-zh模型

图 33 Ernie-health-zh单模型线上分数

    1. 调参调优

在上一节中,我们对各个单模型借助预训练模型进行了建模并进行了线上提交,汇总结果如表 6所示。

表 6 各个单模型线上结果

模型名

模型类别

线上得分(Accuracy)

BERT-wwm-ext

通用模型

0.8321

ERNIE

通用模型

0.8440

Roberta_large

通用模型

0.8459

Roberta_large_pair

领域模型

0.8498

Ernie-health-zh

领域模型

0.8373

       可以看到,各个单模型的结果并不足够好,由于各个单模型都有各自的特点,只能胜任部分特征的任务。接下来,将对各个模型进行调参调优以提高各个模型的鲁棒性以及准确性。

      1. 外部数据增强

数据增强分内部数据增强和外部数据增强。其中内部数据增强是通过内部数据的交叉,传递等关系生成新的数据,并没有引入外来的数据集。而外部数据增强则通过引入外部数据,增加有效训练数据量,使模型拥有更强的能力。但若引入的外部数据与本身任务数据存在偏离,则会导致模型的预测结果出现偏离,因此选择恰当合适的外部数据是本部分的难点。

经过调研,本次使用的外部数据来源为平安医疗科技疾病问答迁移学习比赛,如图 34所示。其数据集是也是医疗领域的文本相似度任务,任务类型与标签格式都与本次赛题比较相似。但如图 35和图 36所示,但其query长度,涉及的病种、以及语言形式与该比赛数据存在差距。

图 34 外部增强数据示例

图 35 外部增强数据长度分布

图 36 外部增强数据标签分布

为了能够选出符合该任务的外部数据,我们采用模型选择的形式对外部数据进行筛选。首先使用比赛中的原始数据构建筛选模型。考虑到外部数据也是一种领域专用数据集,因此我们通过另外一个领域专用模型(Roberta_large_pair)来进行筛选,然后后对所有数据进行预测。取筛选预测的概率处于0.20~0.80

之间的数据作为外部增强数据。

      1. 数据交叉

数据交叉即喂给不同模型不同但有交叉的数据。通过数据交叉,即训练时使用不同的数据进行组合,能够在数据层面增加模型簇的多样性。我们的目标是,将难的数据集(外部数据)给更强大的模型而不是都喂给小模型,使小的模型能够精准预测,大的模型更具鲁棒性。 此外,也可以通过计算KL散度来计算不同组合的模型之间的差异,并通过计算模型两两间的差异和,来获得模型簇整体的多样性,并以这样的形式,选取数据组合。本项目中采用比较简单的方法,我们将结合模型的理论特点和实际线上指标为各个模型分配进行数据交叉后的数据集。此外,在医疗文本相似度任务中,交换两个文本的数据不会改变该文本对的标签。但是对于Bert模型来说,交换文本对的位置,会改变位置编码,能使模型从不同的角度观察这两个文本的相似性。在测试数据增强时,通过计算原始文本对与交换文本对的输出概率的平均值,能够使模型更好地在测试数据上的进行预测。

具体如表 7所示,其中“原始数据”是指原始的train.json数据,“外部数据”是指我们在上一小节中引入的外部数据,“传递性增强”是指在2.2节中通过传递性增强的同类别数据,“新类别增强”是指在2.2节中通过传递性增强的非同类别数据,“随机交换”是指随机交换两个Query对的顺序。

表 7 数据交叉表

模型

原始数据

外部数据

传递性增强

新类别增强

随机交换

BERT-wwm-ext

True

False

True

True

True

ERNIE

True

False

True

True

False

Roberta_large

True

True

True

False

True

Roberta_large_pair

True

True

True

True

True

Ernie-health-zh

True

True

False

True

False

      1. Multi-sample dropout

在训练过程中,由于Bert后接了dropout层。为了加快模型的训练,我们使用multi-sample dropout技术[10]。传统dropout在每轮训练时会从输入中随机选择一组样本(称之为dropout样本),而multi-sample dropout会创建多个dropout样本,然后平均所有样本的损失,从而得到最终的损失。这种方法只要在dropout层后复制部分训练网络,并在这些复制的全连接层之间共享权重就可以了,无需新运算符。通过综合M个dropout样本的损失来更新网络参数,使得最终损失比任何一个dropout样本的损失都低。这样做的效果类似于对一个minibatch中的每个输入重复训练M次。因此,它大大减少了训练迭代次数。一个简单的例子流程如图 37所示。

图 37 不同dropout方法对比图

在本赛题中,我们通过对Bert后的dropout层进行多次sample,并对其多次输出的loss进行平均,这不仅增加了dropout层的稳定性,同时也使得Bert后面的全连接层相较于前面的Bert_base部分能够得到更多的训练。具体的代码实现如所示。

图 38 Multi-sample dropout实现代码

      1. 对抗性权重扰动对抗训练

对抗性权重扰动[6](Adversarial Weight Perturbation, AWP)是许多数据挖掘比赛中常用的微调手段,可以提升模型的抗干扰能力以及泛化能力。我们在此项目中引入了对抗性权重扰动对抗训练,下文将具体介绍。

我们知道模型训练是一个ERM(经验风险最小化,Empirical Risk Minimization)的过程,而对抗训练就是为了增强模型的抗干扰能力。顾名思义,对抗训练就是用对抗样本去训练模型,而通过对原始训练数据添加噪声便得到了对抗样本。从优化的角度来看,对抗训练可以被形式化为一个min-max优化问题,即内层最大化,外层最小化问题,公式如下式(1)所示,其中,L

表示损失函数,fθ

表示模型,xi

表示原始数据样本,yi

表示标签,xi'

表示对抗样本。

minθ1ni=1nmax|xi'-xi|pϵLfθxi',yi

(1)

       从上述公式就能清晰的看出对抗训练的流程,和正常训练类似,比如训练n轮,选取一个mini-batch的数据,对这个batch内的每一个数据点生成一个对抗样本,然后作为对抗训练的数据,代入到模型里更新模型的参数。但在以往的对抗训练中经常会遇到以下几个问题:

  1. 训练速度低:因为对抗训练需要去求解内层最大化问题来产生对抗样本,而通常方式一般需要10步到20步的迭代,因此相对于正常训练,对抗训练会慢10倍到20倍,所以对抗训练在大数据集上很难采用。
  2. 特异性的泛化能力:因为做对抗训练往往使用某种特定的攻击去产生对抗样本,但是这么训练得到的模型可能只对当前的攻击有防御效果。
  3. 较大的泛化误差:正常训练很容易在训练集上得到一个很高的准确率,而且在测试集上的表现也不差。但是对抗训练往往会面临一个问题,在训练集上的效果非常好,但是在测试集上却没有很好地结果。

针对这些问题,对抗性权重扰动出现了。他通过使用即时生成的对抗样本来表征权重损失情况,发现在对抗训练中,更平坦的权重损失通常能得到更小的鲁棒泛化差距。并明确规范化对抗训练的权重损失,形成一种双扰动机制,注入最坏情况下的输入和权重扰动。具体的代码实现比较简单,具体如图 39所示。

图 39 AWP代码实现例子

在调用时只需在训练的循环中调用AWP类的attack_backward方法即可,具体如图 40所示。

图 40 AWP的调用

       在本项目中,我们只对权重进行扰动,不对输入进行扰动,不创建对抗样本。并且我们为了保证对抗是有效的,我们令模型训练完60%的epoch才开始进行扰动。

      1. 随机权重衰减

考虑到本项目的数据集比较小,并且有相当一部分数据是由其他数据组合生成出来的,数据间的耦合性比较强,容易发生过拟合。因此,可以考虑使用随机权重衰减(SWA)[7][12]方法降低过拟合。权重衰减即为L2

范数正则化。正则化通过为模型损失函数添加惩罚项使学出的模型参数值较小,是应对过拟合的常用手段。如式(2),范数正则化在模型原损失函数基础上添加L2

范数的惩罚项,从而得到训练所需要最小化的函数,其中C0

表示原始代价函数,后面的项为L2

正则化项,λ

为权重衰减系数。

C=C0+λ2nωω2

(2)

       由于Python的Transformer库中已经提供了随机权重衰减的设置与接口,此处对实现过程不做赘述。我们将权重衰减系数的范围设置为λ∈1e-4,1e-5

,这个范围将保证模型有比较好的泛化能力,并且也能比较好的完成拟合。

      1. 分层学习率

在使用BERT或者其它预训练模型进行微调并下接其它具体任务相关的模块时,会面临这样一个问题,BERT由于已经进行了预训练,参数已经达到了一个较好的水平,如果要保持其不会降低,学习率就不能太大,而下接结构是从零开始训练,用小的学习率训练不仅学习慢,而且也很难与BERT本体训练同步。因此在训练时候就需要对预训练层设置较小学习率,对下接层设置较大学习率。此外,将学习率进行分层也有助于提高模型的泛化能力。

       考虑到Python的Transformer库中已经为分层学习率设置了比较方便的接口,可以直接调用实现。具体来说,我们在config中设置了变量diff_learning_rate,如果该变量为True,则启用分层学习率;若为False,则不启用。接着,只需要按照层数,修改原代码中的optimizer_grouped_parameters变量,根据先后顺序设置分层的学习率即可。具体实现如图 41中所示,其中,lr表示的是学习率,weight_decay为上一小节中介绍的随机的权重衰减。

图 41 分层学习率与权重衰减的实现

      1. 伪标签

考虑到本项目中的数据相对比较少,我们希望在仅有这些数据的基础上更充分的利用这些数据得到更准确的模型。基于此,我们采用了伪标签技术。伪标签技术是半监督学习中的一个概念,能够帮助模型更好的从无标注的信息中进行学习。与完全的无监督学习相比,半监督学习拥有部分的标注数据和大量的未标注数据,这种形式也更加适合现实场景和竞赛场景。简单来说,伪标签就是用模型已经标注好的数据,去再次喂给模型输入,使用模型标签的结果再次训练模型,从而提高模型的准确率。

本项目中,我们不使用常规的伪标签操作,我们使用一个新的方式。考虑到原数据和伪标签数据的不同,我们认为原训练数据的权重或者说置信度是要比我们的伪标签高的,因此需要在模型训练过程中进行对不同数据来源的判断。我们通过修改模型的损失函数实现这个功能,如式(3)所示,其中lossorigin

表示原数据的loss,lossunlabeled

表示未标记的数据的loss,α

为伪标签权重系数,并且为了保证原数据的权重较高,我们设置α∈0,1

Loss=lossorigin+α×lossunlabeled1+α

(3)

       具体的训练过程如图 42所示。

图 42 伪标签训练示意图

       具体来说我们首先使用标记数据训练模型M

,接着使用模型M

对无标签数据进行预测,得出预测概率P

。接着,为了保证伪标签的有效性,我们通过置信度阈值θ

进行过滤,只将预测概率高于θ

的数据作为伪标签。然后,我们使用新的模型损失函数,组合原数据与伪标签数据,训练新模型M'

。伪标签实现的代码比较简单,在模型训练完进行预测时根据阈值进行过滤即可,此处不赘述。在本文中,为了提高模型训练的泛化性避免过拟合,我们将伪标签方法与后文中交叉验证和模型融合一起组合使用。

      1. 交叉验证

交叉验证是数据挖掘比赛中的常用方法,不仅可以提高模型的鲁棒性,同时也能验证模型测试方法的有效性,并且交叉验证也能够针对数据不足的情况弥补数据不足的缺陷。一般来说,我们常用的是K-fold Cross Validation,即K折交叉验证,将数据集均匀的分为K份,选其中一份作为验证集,剩下的几份作为训练集。但是,由于本次数据集有比较高的内聚性,数据集中的数据都为围绕一些话题的医疗查询句子对。如果不针对每个话题均匀的划分K份数据,会发生训练的数据与测试数据不隶属于同一话题而导致验证误差偏大,这将导致训练失效。因此需要针对不同话题均匀的划分数据集,即保证对于每个话题,训练集和验证集中的比例是大致相等的。通过观察数据集,我们发现数据集都比较有规律,即已经按照话题进行排序。因此,我们在K折交叉验证中对每一折的选取采用了对K取模,按照余数分组的方法,即保证对于每个折,都是在原数据集中以K为步长均匀选出来的。为了保证模型的鲁棒性,在完成所有的折的拆分后,我们会对所有数据进行随机排序,并且我们定义整体的准确度Accuracy为所有折的平均值。具体的代码实现如图 43所示。

图 43 K折交叉验证

    1. 模型融合

在完成各个模型的训练后,我们使用模型融合。这里我们使用比较简单的模型加权融合。在完成训练后,根据我们设置好的权重,对模型分类的概率进行加权并归一化得到最终的加权概率。我们将最终加权的概率作为最终的结果,这种方法可以集成融合多个不同模型的优势。具体实现代码如图 44所示。

图 44 模型融合

       在此部分中,我们先加载各个已经训练好并且在磁盘上持久化的模型,使用测试数据进行推理得到最终预测的概率。接着,我们根据提前设置好的权重对概率进行加权求和并归一化后得到最终的预测结果。

  1. 实验结果

经过上一部分中的模型搭建与训练,可以得出最终的结果,我将结果提交到阿里云天池平台上,可以得到最终结果。最终,我的线上得分Accuracy0.887,在22979个参赛选手中排名第1,位次比0.004%。最终名次,提交结果以及最终排行榜截图如图 45,图 46和图 47所示。

图 45 提交名次与成绩图

图 46 提交结果图

图 47 总排行榜

  1. 总结与展望

本项目使用数据挖掘技术中的自然语言处理技术对医学搜索相关性进行判断。医疗问题是全球范围内的重要问题之一。医疗搜索的相关性也是医生与患者间的桥梁之一,高效且准确的分类查询工具对减轻医生负担、为患者提供及时救助有至关重要的作用。

本项目呈现了任务分析、数据清洗、数据预处理、模型建立、模型评估等数据挖掘一般流程。

针对数据清洗和数据预处理,我们对原始数据进行可视化,发现数据中存在的异常(如标签分布不均匀等),根据标签的实际逻辑,使用数据增强等技术对数据进行了预处理。对于文本不能直接入模的问题,我们使用分词将中文文本转化成了能入模的形式。

针对模型建立和模型评估过程,考虑到数据的数量少,不平衡等特点我采用“通用预训练模型”,“领域预训练模型”,“调参调优”,“模型融合”的四阶段范式,并选择了3个优势互补的通用预训练模型和2个领域专用预训练模型。此外,我还通过对抗性权重扰动对抗训练、随机权重衰减、设置分层学习率、Multi-Sample Dropout、K折交叉验证等技术增强模型的鲁棒性。通过集成融合不同优势的多个模型以及两阶段的伪标签学习,来对结果进行预测,最终取得线上Accuracy 0.887的成绩,在22979个参赛选手中排名第1,位次比0.004%

未来,我们可以在数据增强,模型选择或模型融合等方面进行更多样化的探究,采用简单的随机替换词语、交换、删除、回译等操作引入一些噪声,进一步提升模型的泛化与抗干扰能,选择更合理的预训练模型,采用更合理的模型融合方法(如Stacking)等,以进一步提升模型的性能。

最后,非常感谢陈老师本学期《数据挖掘》的授课。我是计算机体系结构方向的研究生,本科时也选修过陈老师的《数据挖掘导论》课程,之前参加过一些数据挖掘的算法比赛,对数据挖掘、机器学习和深度学习有关一些了解。本学期的课程,进一步开拓了我的视野,为我带来了更多更前沿的数据挖掘相关知识。计算机体系结构未来的发展趋势,也会和数据挖掘与学习的浪潮相碰撞,如何在系统层面对上层AI应用进行特异性的优化(如面向AI应用的存算融合技术)亦或是将AI方法结合到系统的设计哲学中(如基于应用访存模式识别的数据预取优化技术)是系统领域现阶段研究的重大课题。相信这门课的学习,可以促使我多加思考交叉领域的问题启发自己的研究思路,提高自身的学科素养。

  1. 参考文献
  1. Zhang N, Chen M, Bi Z, et al. Cblue: A chinese biomedical language understanding evaluation benchmark[J]. arXiv preprint arXiv:2106.08087, 2021.
  2. Chinese-BERT-wwm : https://github.com/ymcui/Chinese-BERT-wwm
  3. CLUE Pretrained Models : https://github.com/CLUEbenchmark/CLUEPretrainedModels
  4. Bert-Chinese-Text-Classification-Pytorch : https://github.com/649453932/Bert-Chinese-Text-Classification-Pytorch
  5. ERNIE-Pytorch Pretrained Models: https://github.com/nghuyong/ERNIE-Pytorch
  6. Dong Y, Deng Z, Pang T, et al. Adversarial distributional training for robust deep learning[J]. Advances in Neural Information Processing Systems, 2020, 33: 8270-8283.
  7. Izmailov P, Podoprikhin D, Garipov T, et al. Averaging weights leads to wider optima and better generalization[J]. arXiv preprint arXiv:1803.05407, 2018.
  8. LTP: http://ltp.ai/
  9. chip2019: https://biendata.com/competition/chip2019/
  10. Inoue H. Multi-sample dropout for accelerated training and better generalization[J]. arXiv preprint arXiv:1905.09788, 2019.
  11. Devlin J, Chang M W, Lee K, et al. Bert: Pre-training of deep bidirectional transformers for language understanding[J]. arXiv preprint arXiv:1810.04805, 2018.
  12. Wu D, Xia S T, Wang Y. Adversarial weight perturbation helps robust generalization[J]. Advances in Neural Information Processing Systems, 2020, 33: 2958-2969.

http://www.ppmy.cn/embedded/151293.html

相关文章

从零开始开发纯血鸿蒙应用之UI封装

从零开始开发纯血鸿蒙应用 一、题引二、UI 组成三、UI 封装原则四、实现 lib_comps1、封装 UI 样式1.1、attributeModifier 属性1.2、自定义AttributeModifier<T>类 2、封装 UI 组件 五、总结 一、题引 在开始正文前&#xff0c;为了大家能够从本篇博文中&#xff0c;汲…

电化学气体传感器在物联网中的精彩表现

电化学气体传感器是一种基于电化学原理来检测气体浓度的设备。它通过测量目标气体在电极处氧化或还原而产生的电流或电势信号&#xff0c;从而推算出气体的浓度。这种工作原理使得它可被用于探测大多数的有毒有害气体&#xff0c;包括一氧化碳、硫化氢、氯气、二氧化硫等。和武…

APM for Large Language Models

APM for Large Language Models 随着大语言模型&#xff08;LLMs&#xff09;在生产环境中的广泛应用&#xff0c;确保其可靠性和可观察性变得至关重要。应用性能监控&#xff08;APM&#xff09;在这一过程中发挥了关键作用&#xff0c;帮助开发者和运维人员深入了解LLM系统的…

快速了解缓存穿透与缓存雪崩

在缓存系统的使用过程中&#xff0c;缓存穿透和缓存雪崩是两种常见的问题&#xff0c;它们会导致缓存失效&#xff0c;从而对系统性能造成影响。下面我将快速介绍这两个问题及其解决方法。 1. 缓存穿透 (Cache Penetration) 缓存穿透是指客户端请求的某些数据&#xff0c;既不…

LeetCode算法题——移除元素

题目描述 给你一个数组 nums 和一个值 val&#xff0c;你需要原地移除所有数值等于 val 的元素。元素的顺序可能发生改变。然后返回 nums 中与 val 不同的元素的数量。 假设 nums 中不等于 val 的元素数量为 k&#xff0c;要通过此题&#xff0c;您需要执行以下操作&#xff1…

【Python】selenium结合js模拟鼠标点击、拦截弹窗、鼠标悬停方法汇总(使用 execute_script 执行点击的方法)

我们在写selenium获取网络信息的时候&#xff0c;有时候我们会受到对方浏览器的监控&#xff0c;对方通过分析用户行为模式&#xff0c;如点击、滚动、停留时间等&#xff0c;网站可以识别出异常行为&#xff0c;进而对Selenium爬虫进行限制。 这里我们可以加入JavaScript的使…

HTML——16.相对路径

<!DOCTYPE html> <html><head><meta charset"UTF-8"><title></title></head><body><a href"../../fj1/fj2/c.html" target"_blank">链接到c</a><!--相对路径&#xff1a;-->…

Flume的安装和使用

一、安装Flume 1. 下载flume-1.7.0 http://mirrors.shu.edu.cn/apache/flume/1.7.0/apache-flume-1.7.0-bin.tar.gz 2. 解压改名 tar xvf apache-flume-1.7.0-bin.tar.gz mv apache-flume-1.7.0-bin flume 二、配置Flume 1. 配置sh文件 cp conf/flume-env.sh.template …