BERT论文解读及实现(二)

news/2024/11/20 0:27:10/

基于github的bert源码解读

bert github链接: https://github.com/google-research/bert/tree/master
windows流程运行改编版源码及数据百度网盘链接:链接:https://pan.baidu.com/s/1APk9EIh_wuU41fMHSMz3Pg?pwd=s2k7 
提取码:s2k7 

1 模型预训练

1.1训练数据加工

python create_pretraining_data.py--input_file=./sample_text.txt \--output_file=../GLUE/chineseoutput/tf_examples.tfrecord \--vocab_file=../GLUE/BERT_BASE_DIR/uncased_L-12_H-768_A-12/vocab.txt \--do_lower_case=True \--max_seq_length=128 \--max_predictions_per_seq=20 \--masked_lm_prob=0.15 \--random_seed=12345 \--dupe_factor=5

1.1.1 输入文本格式

输入文本格式为,每行一个句子,文档之间使用空行进行分割
line = line.strip() 空行执行到此为空,会创建一个新的文档list
all_documents.append([])

    with tf.gfile.GFile(input_file, "r") as reader:while True:line = tokenization.convert_to_unicode(reader.readline())if not line:break# 文档之间使用空行标识,line = line.strip()# Empty lines are used as document delimitersif not line:# 空行新建[],用于存放新文档的每个句子all_documents.append([])tokens = tokenizer.tokenize(line)if tokens:all_documents[-1].append(tokens)

1.1.2 create_training_instances

  • 读取所有文档并分词,打散文档
  • 使用create_instances_from_document 处理每一个文档

1.1.3 create_instances_from_document

  • 处理单个文档数据,
  • 最大tokens数为 ,最大序列长度 - 3,因为存在3个特色token , [CLS], [SEP], [SEP].
max_num_tokens = max_seq_length - 3
  • 10%的比例使用短句子。
    (i.e., short_seq_prob == 0.1 == 10% of the time) want to use shorter
  • 句子a,和句子b,根据句子长度使用一个文档中多个句子拼接,
    比如文档前四个句子长度才到128,可以使用1、2句子作为a句子,2,3句子作为b句。此操作, is_random_next = Flase
  • 当随机概率小于0.5时,会随机从其他文档中选择句子作为下一句,此时is_random_next = True
  target_seq_length = max_num_tokensif rng.random() < short_seq_prob:target_seq_length = rng.randint(2, max_num_tokens)

1.1.4 create_masked_lm_predictions

  • 去掉句子里的[CLS]、 [SEP]
  • 提取句子里其他token的id,然后随机打散(前15%的id会被mask掉)
  • 使用下面的mask机制,将前15%的数量的tokens mask掉。并记录mask的位置和原本对应的 token。

mask 机制

    for index in index_set:covered_indexes.add(index)masked_token = None# 80% of the time, replace with [MASK]if rng.random() < 0.8:masked_token = "[MASK]"else:# 10% of the time, keep originalif rng.random() < 0.5:masked_token = tokens[index]# 10% of the time, replace with random wordelse:masked_token = vocab_words[rng.randint(0, len(vocab_words) - 1)]output_tokens[index] = masked_token

1.2 模型训练

模型训练脚本:

python run_pretraining.py \--input_file=../GLUE/chineseoutput/tf_examples.tfrecord \--output_dir=../GLUE/pretraining_output \--do_train=True \--do_eval=True \--bert_config_file=../GLUE/BERT_BASE_DIR/uncased_L-12_H-768_A-12/bert_config.json \--init_checkpoint=../GLUE/BERT_BASE_DIR/uncased_L-12_H-768_A-12/bert_model.ckpt \--train_batch_size=32 \--max_seq_length=128 \--max_predictions_per_seq=20 \--num_train_steps=20 \--num_warmup_steps=10 \--learning_rate=2e-5

1.2.1 模型输入

  • 模型输入特征及形状如下,其中32为batch_size,128为句子长度
    dict_keys([‘input_ids’, ‘input_mask’, ‘masked_lm_ids’, ‘masked_lm_positions’, ‘masked_lm_weights’, ‘next_sentence_labels’, ‘segment_ids’])
name = input_ids, shape = (32, 128)
name = input_mask, shape = (32, 128)
name = masked_lm_ids, shape = (32, 20)
name = masked_lm_positions, shape = (32, 20)
name = masked_lm_weights, shape = (32, 20)name = next_sentence_labels, shape = (32, 1)
name = segment_ids, shape = (32, 128)

1.2.2 模型结构

embedding 层
embedding_lookup
① word embedding
输出 embedding_table 维度shape=[30522, 768]
输出 embedding_output shape=[32,128,768]
② embedding_postprocessor
位置 embedding和token_type_ids embedding
token_type_ids embedding 词典为 shape =[2,768]
position embedding 词典大小为 shape=[512,768]

embedding 层输出为:word embedding + token_type_ids embedding + position embedding
shape=[32,128,768]

transformer 层
输入为 embedding层的结果: shape=[32,128,768]
输出为经过transformer 层的结果: shape=[32,128,768]

pooled层
取transformer输出结果中,句子维度第一个数,即[CLS]所在的位置。
然后再过一层全连接层,输出 shape=[32,768]

first_token_tensor = tf.squeeze(self.sequence_output[:, 0:1, :], axis=1)
self.pooled_output = tf.layers.dense(first_token_tensor,config.hidden_size,activation=tf.tanh,kernel_initializer=create_initializer(config.initializer_range))

get_masked_lm_output
① 提取mask位置的数据
name = masked_lm_ids, shape = (32, 20)
输入为transform最后一层数据 shape=[32,128,768]
根据mask所在位置索引,提取mask后的数据,shape=[32,20,768],将数据展平,shape=[640,768]
② 将上面数据shape=[640,768] 再经过一层全连接层,和layer normlization层,输出shape=[640,768]
③ 将上面层与字典table embedding 进行计算,计算,shape=[640,30522],然后再过一个softmax,即可得到,在字典表中的概率。

get_next_sentence_output
输入为CLS对应的数据,即上面pooled层输出,shape =[32,768]
经过全连接层后输出shape =[2,768]
然后经过softmax,得到分类概率 shape =[2,768]

1.2.3 损失函数

  • Mask 损失:
    将上面get_masked_lm_output 与真是label值进行交叉熵损失计算。
    per_example_loss = -tf.reduce_sum(log_probs * one_hot_labels, axis=[-1])numerator = tf.reduce_sum(label_weights * per_example_loss)
  • next_sentence_loss
   log_probs = tf.nn.log_softmax(logits, axis=-1)labels = tf.reshape(labels, [-1])one_hot_labels = tf.one_hot(labels, depth=2, dtype=tf.float32)per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)loss = tf.reduce_mean(per_example_loss)
  • 整体损失
   total_loss = masked_lm_loss + next_sentence_loss

2 模型下游任务微调

未完待续。。。。。。


http://www.ppmy.cn/news/803646.html

相关文章

微信小程序设置 tabbar icon 大小

微信小程序的 tabbar icon 大小可以通过以下方式进行设置&#xff1a; 将 tabbar icon 图片制作成合适的尺寸&#xff1a;你可以使用设计工具&#xff08;如 Photoshop、Sketch 等&#xff09;将图标调整为合适的大小。通常建议使用 48x48 或 60x60 像素的图标。 在 app.json …

OpenCV 入门教程:目标检测与跟踪概念

OpenCV 入门教程&#xff1a;目标检测与跟踪概念 导语一、目标检测与跟踪概述二、目标检测与跟踪方法1.1 基于特征的方法1.2 学习-based 方法1.3 基于滤波器的方法1.4 基于深度学习的方法 三、目标检测与跟踪实例总结 导语 目标检测与跟踪是计算机视觉领域的重要任务&#xff…

资深影迷不可不知的宽高比:Aspect Ratio 电影画面比例

我们看到的电影 电影胶片 电影在开始放映之前&#xff0c;大家看到的肯定是一块白色的银幕&#xff08;如果没看见那是有幕布挡着&#xff09;。电影开映后&#xff0c;一束强光打在银幕上&#xff0c;画面出现在银幕上&#xff0c;于是电影就开始了。 这个看似很神奇…

苹果谷歌脸书大佬前往游说!欧盟将首次对AI进行监管

1、欧盟将首次对AI进行监管&#xff0c;苹果谷歌脸书大佬前往游说 谷歌母公司Alphabet CEO桑达尔皮查伊(SundarPichai)、苹果负责人工智能业务的高级副总裁约翰詹南德雷亚&#xff08;John Giannandrea&#xff09;近期都访问了布鲁塞尔。 星期一&#xff0c;脸书CEO马克扎克伯…

ES6: Symbol概念与用法举例

概念: ES6 引入了一种新的原始数据类型Symbol&#xff0c;表示独一无二的值。 1-使用Symbol作为对象属性名 let name Symbol() let age Symbol() var obj {[name]:"kerwin",[age]:100 }举例理解: a.给对象添加独一无二的属性 let obj {name: Jack }let name …

Zynq 多个UDP客户端组网启动问题(Auto negotiation error)PS:附UDP客户端初始化代码

最近正在进行一个Zynq项目&#xff0c;根据设计需求&#xff0c;需要将上位机作为UDP服务器&#xff0c;而FPGA则充当UDP客户端。同时&#xff0c;服务器需要能够接收和控制多个UDP客户端。 开发过程中&#xff0c;我是基于lwip UDP Perf Client 官方模版开发的。我遇到了以下几…

2023最新SSM计算机毕业设计选题大全(附源码+LW)之java自助旅游平台v294n

毕业设计说实话没有想象当中的那么难&#xff0c;导师也不会说刻意就让你毕设不通过&#xff0c;不让你毕业啥的&#xff0c;你只要不是太过于离谱的&#xff0c;都能通过的。首先你得要对你在大学期间所学到的哪方面比较熟悉&#xff0c;语言比如JAVA、PHP等这些&#xff0c;数…

2023最新SSM计算机毕业设计选题大全(附源码+LW)之java足球爱好者服务平台i387z

毕业设计说实话没有想象当中的那么难&#xff0c;导师也不会说刻意就让你毕设不通过&#xff0c;不让你毕业啥的&#xff0c;你只要不是太过于离谱的&#xff0c;都能通过的。首先你得要对你在大学期间所学到的哪方面比较熟悉&#xff0c;语言比如JAVA、PHP等这些&#xff0c;数…