跟着问题学19——BERT详解(2)

news/2024/12/18 13:14:28/

预训练策略

BERT模型的预训练基于两个任务:

    屏蔽语言建模

    下一句预测

在深入屏蔽语言建模之间,我们先来理解一下语言建模任务的原理。

语言建模

在语言建模任务中,我们训练模型给定一系列单词来预测下一个单词。可以把语言建模分为两类:

    自回归语言建模

    自编码语言建模

自回归语言建模

我们还可以将自回归语言建模归类为:

    前向(左到右)预测

    反向(右到左)预测

   通过实例来理解。考虑文本Paris is a beautiful city. I love Paris。假设我们移除了单词city然后替换为空白符__:

Paris is a beautiful __. I love Paris

现在,我们的模型需要预测空白符实际的单词。如果使用前向预测,那么我们的模型以从左到右的顺序阅读序列中的单词,直到空白符:

Paris is a beautiful __.

如果我们使用反向预测,那么我们的模型以从右到左的顺序阅读序列中的单词,直到空白符:

__. I love Paris     paris love i ,_?

因此,自回归模型天然就是单向的,意味着它们只会以一个方向阅读输入序列。

自编码语言建模

自编码语言建模任务同时利用了前向(左到右)和反向(右到左)预测的优势。即,它们在预测时同时读入两个方向的序列。因此,我们可以说自编码语言模型天生就是双向的。

为了预测空白符,自编码语言模型同时从两个方向阅读序列,如下所示:

Paris is a beautiful __. I love Paris

因此双向的模型能获得更好的结果。

屏蔽语言建模

BERT是一个自编码语言模型,即预测时同时从两个方向阅读序列。在一个屏蔽语言建模任务中,对于给定的输入序列,我们随机屏蔽15%的单词,然后训练模型去预测这些屏蔽的单词。为了做到这一点,我们的模型以两个方向读入序列然后尝试预测屏蔽的单词。

举个例子。我们考虑上面见到的句子:Paris is a beautiful city', and 'I love Paris。首先,我们将句子分词,得到一些标记:

tokens = [Paris, is, a beautiful, city, I, love, Paris]

接着还是增加[CLS]标记到第一个句子的开头,增加[SEP]标记到每个句子的结尾:

tokens = [ [CLS], Paris, is, a beautiful, city, [SEP], I, love, Paris, [SEP] ]

接下来,我们在上面的标记列表中随机地屏蔽15%的标记(单词)。假设我们屏蔽单词city,然后用[MASK]标记替换这个单词,结果为:

tokens = [ [CLS], Paris, is, a beautiful, [MASK], [SEP], I, love, Paris, [SEP] ]

现在训练我们的BERT模型去预测被屏蔽的标记。

这里有一个小问题。 以这种方式屏蔽标记会在预训练和微调之间产生差异。即,我们训练BERT通过预测[MASK]标记。训练完之后,我们可以为下游任务微调预训练的BERT模型,比如情感分析任务。但在微调期间,我们的输入不会有任何的[MASK]标记。因此,它会导致 BERT 的预训练方式与微调方式不匹配。

为了解决这个问题,我们应用80-10-10%规则。我们知道我们会随机地屏蔽句子中15%的标记。现在,对于这些15%的标记,我们做下面的事情:

    80%的概率,我们用[MASK]标记替换该标记。因此,80%的情况下,输入会变成如下:

    tokens = [ [CLS], Paris, is, a beautiful, [MASK], [SEP], I, love, Paris, [SEP] ]

   

    10%的概率,我们用一个随机标记(单词)替换该标记。所以,10%的情况下,输入变为:

    tokens = [ [CLS], Paris, is, a beautiful, love, [SEP], I, love, Paris, [SEP] ]

   

    剩下10%的概率,我们不做任何替换。因此,此时输入不变:

    tokens = [ [CLS], Paris, is, a beautiful, city, [SEP], I, love, Paris, [SEP] ]

   

在分词和屏蔽之后,我们分别将这些输入标记喂给标记嵌入、片段嵌入和位置嵌入层,然后得到输入嵌入。

然后,我们将输入嵌入喂给BERT。如下所示,BERT接收输入然后返回每个标记的嵌入表示作为输出。 R_[CLS] 代表输入标记[CLS]的嵌入表示, R_Paris 代表标记Paris的嵌入表示,以此类推。

在本例中,我们使用BERT-base,即有12个编码器层,12个注意力头和768个隐藏单元。

我们得到了每个标记的嵌入表示R。现在, 我们如何用这些表示预测屏蔽的标记?

为了预测屏蔽的标记,我们将BERT返回的屏蔽的单词表示 R_[MASK]喂给一个带有softmax激活函数的前馈神经网络。然后该网络输出词表中每个单词属于该屏蔽的单词的概率。如下图所示。这里输入嵌入层没有画出来以减小版面:

从上图可以看到,单词city属于屏蔽单词的概率最高。因此,我们的模型会预测屏蔽单词为city。

注意在初始的迭代中,我们的模型不会输出正确的概率,因为前馈网络和BERT编码器层的参数还没有被优化。然而,通过一系列的迭代之后,我们更新了前馈网络和BERT编码器层的参数,然后学到了优化的参数。

屏蔽语言建模也被称为完形填空(cloze)任务。我们已经知道了如何使用屏蔽语言建模任务训练BERT模型。而屏蔽输入标记时,我们也可以使用一个有点不同的方法,叫作全词屏蔽(whole word masking,WWM)。

全词屏蔽

同样,我们以实例来理解全词屏蔽是如何工作的。考虑句子Let us start pretraining the model。记住BERT使用WordPiece分词器,所以,在使用该分词器之后,我们得到下面的标记:

tokens = [let, us, start, pre, ##train, ##ing, the, model]

然后增加[CLS]和[SEP]标记:

tokens = [[CLS], let, us, start, pre, ##train, ##ing, the, model, [SEP]]

接着随机屏蔽15%的单词。假设屏蔽后的结果为:

tokens = [[CLS], [MASK], us, start, pre, [MASK], ##ing, the, model, [SEP]]

从上面可知,我们屏蔽了单词let和##train。其中##train是单词pretraining的一个子词。在全词屏蔽模型中,如果子词被屏蔽了,然后我们屏蔽与该子词对应单词的所有子词。因此,我们的标记变成了下面的样子:

tokens = [[CLS], [MASK], us, start, [MASK], [MASK], [MASK], the, model, [SEP]]

注意我们也需要保持我们的屏蔽概率为15%。所以,当屏蔽子词对应的所有单词后,如果超过了15%的屏蔽率,我们可以取消屏蔽其他单词。如下所示,我们取消屏蔽单词let来控制屏蔽率:

tokens = [[CLS], let, us, start, [MASK], [MASK], [MASK], the, model, [SEP]]

这样,我们使用全词屏蔽来屏蔽标记。

下一句预测

下一句预测(next sentence prediction,NSP)是另一个用于训练BERT模型的任务。NSP是分类任务,在此任务中,我们输入两个句子,然后BERT需要判断第二个句子是否为第一个句子的下一句。

考虑下面两个句子:

Sentence A: She cooked pasta.

Sentence B: It was delicious.

这两个句子中,B就是A的下一句,所以我们标记这对句子为isNext。

然后看另外两个句子:

Sentence A: Turn the radio on.

Sentence B: She bought a new hat.

显然B不是A的下一句,所以我们标记这个句子对为notNext。

在NSP任务中,我们模型的目标是预测句子对属于isNext还是notNext。

那么NSP任务有什么用?通过运行NSP任务,我们的模型可以理解两个句子之间的关系,这会有利于很多下游任务,像问答和文本生成。

那么如何获取NSP任务的数据集?我们可以从任何单语语料库中生成数据集。假设我们有一些文档。对于isNext类别,我们从某篇文档中抽取任意相连的句子,然后将它们标记为isNext;对于notNext类别,我们从一篇文档中取一个句子,然后另一个句子随机的从所有文档中取,标记为notNext。同时我们需要保证数据集中50%的句子对属于isNext,剩下50%的句子对属于notNext。

假设我们这样得到如下所示的数据集:

我们以上面数据集中第一个句子对为例。首先,我们进行分词,得到:

tokens = [She, cooked, pasta, It, was, delicious]

接下来,增加[CLS]和SEP标记:

tokens = [[CLS], She, cooked, pasta, [SEP], It, was, delicious, [SEP]]

然后我们把这个输入标记喂给标记嵌入、片段嵌入和位置嵌入层,得到输入嵌入。

接着把输入嵌入喂给BERT获得每个标记的嵌入表示。如下图所示, R_[CLS] 代表标记[CLS]的嵌入表示。

为了进行分类,我们简单地将[CLS]标记的嵌入表示喂给一个带有softmax函数的全连接网络,该网络会返回我们输入的句子对属于isNext和notNext的概率。

因为[CLS]标记保存了所有标记的聚合表示,也就得到了整个输入的信息?。所以我们可以直接拿该标记对应的嵌入表示来进行预测。如下图所示:

上面我们可以看到,最终的全连接网络输出isNext的概率较高。

参考资料

Getting Started with Google BERT

https://blog.csdn.net/yjw123456/article/details/120211601


http://www.ppmy.cn/news/1556128.html

相关文章

排序算法总结(python实现)

前言 排序算法是一类常见的算法,在学习算法的过程中,都会学习这些排序算法的实现。尽管现在大多数程序语言以及扩展包中对排序算法进行了封装,只要调用接口函数即可实现算法。学习和总结排序算法对于理解算法思维还是很有帮助的。因此本文在…

模拟登录网页

模拟登录与数据采集 今天我们讨论了如何通过 Python 模拟登录并抓取登录后的数据,主要涵盖了以下内容: 模拟登录步骤: 分析登录页面:使用浏览器开发者工具(F12)分析登录表单,提取表单字段、提…

【echarts】数据过多时可以左右滑动查看(可鼠标可滚动条)

1. 鼠标左右拖动 在和 series 同级的地方配置 dataZoom: dataZoom: [{type: inside, // inside 鼠标左右拖图表,滚轮缩放; slider 使用滑动条start: 0, // 左边的滑块位置,表示从 0 开始显示end: 60, // 右边的滑块位置&#xf…

企业车辆管理系统(源码+数据库+报告)

一、项目介绍 352.基于SpringBoot的企业车辆管理系统,系统包含两种角色:管理员、用户,系统分为前台和后台两大模块 二、项目技术 编程语言:Java 数据库:MySQL 项目管理工具:Maven 前端技术:Vue 后端技术&a…

python基础:(八)文件

目录 一.从文件中读取数据1.1读取整个文件1.2文件路劲1.3逐行读取 二.写入文件 一.从文件中读取数据 各位小伙伴,文件这一块得好好学,多看多敲代码,以后处理数据,写爬虫少不了这个,先从基础(简单的&#x…

Loadsh源码分析-filter,find,findLast,reject,partition

lodash源码研读之filter,find,findLast,reject,partition 一、源码地址 GitHub 地址: GitHub - lodash/lodash: A modern JavaScript utility library delivering modularity, performance, & extras.官方文档地址: Lodash 官方文档 二、结构分析 结构框图省略。 三、函…

css 实现呼吸灯效果

先看效果&#xff1a; 动画的结果就想实在呼吸,完整的代码如下&#xff1a; <template><div class"container"><div class"long-breath"></div></div> </template><style lang"less"> html, body{h…

什么是评价搭配

一、评价搭配的概念 评价搭配是指在文本中&#xff0c;由评价词&#xff08;如 “好”“坏”“优秀”“糟糕” 等表达主观意见的词&#xff09;和被评价对象&#xff08;如产品名称、服务类型、人物等&#xff09;组成的语义单元。例如&#xff0c;在 “这部手机的拍照效果很好…