预训练策略
BERT模型的预训练基于两个任务:
屏蔽语言建模
下一句预测
在深入屏蔽语言建模之间,我们先来理解一下语言建模任务的原理。
语言建模
在语言建模任务中,我们训练模型给定一系列单词来预测下一个单词。可以把语言建模分为两类:
自回归语言建模
自编码语言建模
自回归语言建模
我们还可以将自回归语言建模归类为:
前向(左到右)预测
反向(右到左)预测
通过实例来理解。考虑文本Paris is a beautiful city. I love Paris。假设我们移除了单词city然后替换为空白符__:
Paris is a beautiful __. I love Paris
现在,我们的模型需要预测空白符实际的单词。如果使用前向预测,那么我们的模型以从左到右的顺序阅读序列中的单词,直到空白符:
Paris is a beautiful __.
如果我们使用反向预测,那么我们的模型以从右到左的顺序阅读序列中的单词,直到空白符:
__. I love Paris paris love i ,_?
因此,自回归模型天然就是单向的,意味着它们只会以一个方向阅读输入序列。
自编码语言建模
自编码语言建模任务同时利用了前向(左到右)和反向(右到左)预测的优势。即,它们在预测时同时读入两个方向的序列。因此,我们可以说自编码语言模型天生就是双向的。
为了预测空白符,自编码语言模型同时从两个方向阅读序列,如下所示:
Paris is a beautiful __. I love Paris
因此双向的模型能获得更好的结果。
屏蔽语言建模
BERT是一个自编码语言模型,即预测时同时从两个方向阅读序列。在一个屏蔽语言建模任务中,对于给定的输入序列,我们随机屏蔽15%的单词,然后训练模型去预测这些屏蔽的单词。为了做到这一点,我们的模型以两个方向读入序列然后尝试预测屏蔽的单词。
举个例子。我们考虑上面见到的句子:Paris is a beautiful city', and 'I love Paris。首先,我们将句子分词,得到一些标记:
tokens = [Paris, is, a beautiful, city, I, love, Paris]
接着还是增加[CLS]标记到第一个句子的开头,增加[SEP]标记到每个句子的结尾:
tokens = [ [CLS], Paris, is, a beautiful, city, [SEP], I, love, Paris, [SEP] ]
接下来,我们在上面的标记列表中随机地屏蔽15%的标记(单词)。假设我们屏蔽单词city,然后用[MASK]标记替换这个单词,结果为:
tokens = [ [CLS], Paris, is, a beautiful, [MASK], [SEP], I, love, Paris, [SEP] ]
现在训练我们的BERT模型去预测被屏蔽的标记。
这里有一个小问题。 以这种方式屏蔽标记会在预训练和微调之间产生差异。即,我们训练BERT通过预测[MASK]标记。训练完之后,我们可以为下游任务微调预训练的BERT模型,比如情感分析任务。但在微调期间,我们的输入不会有任何的[MASK]标记。因此,它会导致 BERT 的预训练方式与微调方式不匹配。
为了解决这个问题,我们应用80-10-10%规则。我们知道我们会随机地屏蔽句子中15%的标记。现在,对于这些15%的标记,我们做下面的事情:
80%的概率,我们用[MASK]标记替换该标记。因此,80%的情况下,输入会变成如下:
tokens = [ [CLS], Paris, is, a beautiful, [MASK], [SEP], I, love, Paris, [SEP] ]
10%的概率,我们用一个随机标记(单词)替换该标记。所以,10%的情况下,输入变为:
tokens = [ [CLS], Paris, is, a beautiful, love, [SEP], I, love, Paris, [SEP] ]
剩下10%的概率,我们不做任何替换。因此,此时输入不变:
tokens = [ [CLS], Paris, is, a beautiful, city, [SEP], I, love, Paris, [SEP] ]
在分词和屏蔽之后,我们分别将这些输入标记喂给标记嵌入、片段嵌入和位置嵌入层,然后得到输入嵌入。
然后,我们将输入嵌入喂给BERT。如下所示,BERT接收输入然后返回每个标记的嵌入表示作为输出。 R_[CLS] 代表输入标记[CLS]的嵌入表示, R_Paris 代表标记Paris的嵌入表示,以此类推。
在本例中,我们使用BERT-base,即有12个编码器层,12个注意力头和768个隐藏单元。
我们得到了每个标记的嵌入表示R。现在, 我们如何用这些表示预测屏蔽的标记?
为了预测屏蔽的标记,我们将BERT返回的屏蔽的单词表示 R_[MASK]喂给一个带有softmax激活函数的前馈神经网络。然后该网络输出词表中每个单词属于该屏蔽的单词的概率。如下图所示。这里输入嵌入层没有画出来以减小版面:
从上图可以看到,单词city属于屏蔽单词的概率最高。因此,我们的模型会预测屏蔽单词为city。
注意在初始的迭代中,我们的模型不会输出正确的概率,因为前馈网络和BERT编码器层的参数还没有被优化。然而,通过一系列的迭代之后,我们更新了前馈网络和BERT编码器层的参数,然后学到了优化的参数。
屏蔽语言建模也被称为完形填空(cloze)任务。我们已经知道了如何使用屏蔽语言建模任务训练BERT模型。而屏蔽输入标记时,我们也可以使用一个有点不同的方法,叫作全词屏蔽(whole word masking,WWM)。
全词屏蔽
同样,我们以实例来理解全词屏蔽是如何工作的。考虑句子Let us start pretraining the model。记住BERT使用WordPiece分词器,所以,在使用该分词器之后,我们得到下面的标记:
tokens = [let, us, start, pre, ##train, ##ing, the, model]
然后增加[CLS]和[SEP]标记:
tokens = [[CLS], let, us, start, pre, ##train, ##ing, the, model, [SEP]]
接着随机屏蔽15%的单词。假设屏蔽后的结果为:
tokens = [[CLS], [MASK], us, start, pre, [MASK], ##ing, the, model, [SEP]]
从上面可知,我们屏蔽了单词let和##train。其中##train是单词pretraining的一个子词。在全词屏蔽模型中,如果子词被屏蔽了,然后我们屏蔽与该子词对应单词的所有子词。因此,我们的标记变成了下面的样子:
tokens = [[CLS], [MASK], us, start, [MASK], [MASK], [MASK], the, model, [SEP]]
注意我们也需要保持我们的屏蔽概率为15%。所以,当屏蔽子词对应的所有单词后,如果超过了15%的屏蔽率,我们可以取消屏蔽其他单词。如下所示,我们取消屏蔽单词let来控制屏蔽率:
tokens = [[CLS], let, us, start, [MASK], [MASK], [MASK], the, model, [SEP]]
这样,我们使用全词屏蔽来屏蔽标记。
下一句预测
下一句预测(next sentence prediction,NSP)是另一个用于训练BERT模型的任务。NSP是二分类任务,在此任务中,我们输入两个句子,然后BERT需要判断第二个句子是否为第一个句子的下一句。
考虑下面两个句子:
Sentence A: She cooked pasta.
Sentence B: It was delicious.
这两个句子中,B就是A的下一句,所以我们标记这对句子为isNext。
然后看另外两个句子:
Sentence A: Turn the radio on.
Sentence B: She bought a new hat.
显然B不是A的下一句,所以我们标记这个句子对为notNext。
在NSP任务中,我们模型的目标是预测句子对属于isNext还是notNext。
那么NSP任务有什么用?通过运行NSP任务,我们的模型可以理解两个句子之间的关系,这会有利于很多下游任务,像问答和文本生成。
那么如何获取NSP任务的数据集?我们可以从任何单语语料库中生成数据集。假设我们有一些文档。对于isNext类别,我们从某篇文档中抽取任意相连的句子,然后将它们标记为isNext;对于notNext类别,我们从一篇文档中取一个句子,然后另一个句子随机的从所有文档中取,标记为notNext。同时我们需要保证数据集中50%的句子对属于isNext,剩下50%的句子对属于notNext。
假设我们这样得到如下所示的数据集:
我们以上面数据集中第一个句子对为例。首先,我们进行分词,得到:
tokens = [She, cooked, pasta, It, was, delicious]
接下来,增加[CLS]和SEP标记:
tokens = [[CLS], She, cooked, pasta, [SEP], It, was, delicious, [SEP]]
然后我们把这个输入标记喂给标记嵌入、片段嵌入和位置嵌入层,得到输入嵌入。
接着把输入嵌入喂给BERT获得每个标记的嵌入表示。如下图所示, R_[CLS] 代表标记[CLS]的嵌入表示。
为了进行分类,我们简单地将[CLS]标记的嵌入表示喂给一个带有softmax函数的全连接网络,该网络会返回我们输入的句子对属于isNext和notNext的概率。
因为[CLS]标记保存了所有标记的聚合表示,也就得到了整个输入的信息?。所以我们可以直接拿该标记对应的嵌入表示来进行预测。如下图所示:
上面我们可以看到,最终的全连接网络输出isNext的概率较高。
参考资料
Getting Started with Google BERT
https://blog.csdn.net/yjw123456/article/details/120211601