SimBERT Training: A Rundown of the Pitfalls


Friends, if your work involves sentence-similarity computation, you may have heard of the SimBERT model. If you ever need to train your own domain-specific version of it, there are a few small pitfalls waiting for you. Here is a rundown of the common ones:

(1) Incompatible GPU

Error message: failed to run cuBLAS routine cublasSgemm_v2: CUBLAS_STATUS_EXECUTION_FAILED Blas GEMM launch failed

This one is bizarre, and only repeated trial and error reveals it: with tf==1.14 you should stick to a 20-series NVIDIA card. Don't bother with a 3060; it will throw this error. The underlying reason is that 30-series (Ampere) cards require CUDA 11, while the TF 1.x builds target CUDA 10, so their cuBLAS GEMM kernels cannot launch on the newer architecture.
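Before committing to a long training run, a quick way to test whether a given card + TF 1.x build will hit this error is to run a single matrix multiplication on the GPU, since that is exactly the op that fails. A minimal sanity-check sketch (not from the original post; the allow_growth option is also worth trying, because memory contention can trigger the same cuBLAS failure):

# Minimal GPU sanity check for TF 1.x: the cuBLAS error above fires on the
# first GEMM, so a tiny matmul fails fast if the card/toolkit combo is bad.
import tensorflow as tf

config = tf.ConfigProto()
config.gpu_options.allow_growth = True  # also rules out memory-contention-induced GEMM failures

with tf.device('/gpu:0'):
    a = tf.random.normal([256, 256])
    b = tf.random.normal([256, 256])
    c = tf.matmul(a, b)  # the op that raises CUBLAS_STATUS_EXECUTION_FAILED

with tf.Session(config=config) as sess:
    sess.run(c)
    print('GPU matmul OK')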

(2) CUDA version + GPU mismatch

Error message: keras/utils/data_utils.py:718: UserWarning: An input could not be retrieved. It could be because a worker has died. We do not have any information on the lost sample.

This error is nasty and took a long time to track down. As with the previous issue, you need a 20-series card: the combination 2080 Ti + tensorflow-gpu==1.14 + bert4keras==0.7.7 made the error go away.
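Separately from the hardware fix: this warning comes from Keras's multi-process data loading, so if it still appears on otherwise-working hardware, a workaround worth trying (a sketch, not what was ultimately done here) is to run the generator in the main process. Using the names from the training script below:

# Hedged workaround sketch: run data loading in the main process so no
# worker process can "die" and drop samples. Slower, but removes this
# failure mode entirely.
model.fit_generator(
    train_generator.forfit(),
    steps_per_epoch=steps_per_epoch,
    epochs=epochs,
    callbacks=[evaluator],
    workers=0,                  # 0 = run the generator on the main thread
    use_multiprocessing=False,  # skip the multi-process queue that loses samples
)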

Working training configuration & code

bert4keras==0.7.7
Keras==2.3.1
tensorflow-gpu==1.15
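For reference, one way to pin this combination in a fresh environment (assuming CUDA 10 is already installed, since the TF 1.x builds require it):

pip install tensorflow-gpu==1.15 keras==2.3.1 bert4keras==0.7.7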


#! -*- coding: utf-8 -*-
# SimBERT training code
# Training environment: tensorflow 1.14 + keras 2.3.1 + bert4keras 0.7.7

from __future__ import print_function
import json
import numpy as np
from collections import Counter
from bert4keras.backend import keras, K
from bert4keras.layers import Loss
from bert4keras.models import build_transformer_model
from bert4keras.tokenizers import Tokenizer, load_vocab
from bert4keras.optimizers import Adam, extend_with_weight_decay
from bert4keras.snippets import DataGenerator
from bert4keras.snippets import sequence_padding
from bert4keras.snippets import text_segmentate
from bert4keras.snippets import AutoRegressiveDecoder
from bert4keras.snippets import uniout

# Basic settings
maxlen = 32
batch_size = 128
steps_per_epoch = 1000
epochs = 10000
corpus_path = 'data_sample.json'

# BERT configuration
config_path = '/openbayes/home/chinese_simbert_L-6_H-384_A-12/bert_config.json'
checkpoint_path = '/openbayes/home/chinese_simbert_L-6_H-384_A-12/bert_model.ckpt'
dict_path = '/openbayes/home/chinese_simbert_L-6_H-384_A-12/vocab.txt'

# Load and simplify the vocabulary, then build the tokenizer
token_dict, keep_tokens = load_vocab(
    dict_path=dict_path,
    simplified=True,
    startswith=['[PAD]', '[UNK]', '[CLS]', '[SEP]'],
)
tokenizer = Tokenizer(token_dict, do_lower_case=True)


def read_corpus(corpus_path):
    """Read the corpus (one JSON object per line), looping forever."""
    while True:
        with open(corpus_path) as f:
            for l in f:
                yield l  # raw JSON line; parsed inside the data generator


def truncate(text):
    """Truncate a sentence at the first separator past maxlen - 2 characters."""
    seps, strips = u'\n。!?!?;;,, ', u';;,, '
    return text_segmentate(text, maxlen - 2, seps, strips)[0]


class data_generator(DataGenerator):
    """Data generator: each sample contributes (text, synonym) in both
    orders, so adjacent rows in a batch are mutual paraphrases."""
    def __init__(self, *args, **kwargs):
        super(data_generator, self).__init__(*args, **kwargs)
        self.some_samples = []

    def __iter__(self, random=False):
        batch_token_ids, batch_segment_ids = [], []
        for is_end, d in self.sample(random):
            d = json.loads(d)
            text, synonyms = d['text'], d['synonyms']
            synonyms = [text] + synonyms
            np.random.shuffle(synonyms)
            text, synonym = synonyms[:2]
            text, synonym = truncate(text), truncate(synonym)
            self.some_samples.append(text)
            if len(self.some_samples) > 1000:
                self.some_samples.pop(0)
            token_ids, segment_ids = tokenizer.encode(
                text, synonym, max_length=maxlen * 2
            )
            batch_token_ids.append(token_ids)
            batch_segment_ids.append(segment_ids)
            token_ids, segment_ids = tokenizer.encode(
                synonym, text, max_length=maxlen * 2
            )
            batch_token_ids.append(token_ids)
            batch_segment_ids.append(segment_ids)
            if len(batch_token_ids) == self.batch_size or is_end:
                batch_token_ids = sequence_padding(batch_token_ids)
                batch_segment_ids = sequence_padding(batch_segment_ids)
                yield [batch_token_ids, batch_segment_ids], None
                batch_token_ids, batch_segment_ids = [], []


class TotalLoss(Loss):
    """The loss has two parts: the seq2seq cross-entropy and the similarity cross-entropy."""
    def compute_loss(self, inputs, mask=None):
        loss1 = self.compute_loss_of_seq2seq(inputs, mask)
        loss2 = self.compute_loss_of_similarity(inputs, mask)
        self.add_metric(loss1, name='seq2seq_loss')
        self.add_metric(loss2, name='similarity_loss')
        return loss1 + loss2

    def compute_loss_of_seq2seq(self, inputs, mask=None):
        y_true, y_mask, _, y_pred = inputs
        y_true = y_true[:, 1:]  # target token_ids
        y_mask = y_mask[:, 1:]  # segment_ids, which mark exactly the part to predict
        y_pred = y_pred[:, :-1]  # predicted sequence, shifted one position
        loss = K.sparse_categorical_crossentropy(y_true, y_pred)
        loss = K.sum(loss * y_mask) / K.sum(y_mask)
        return loss

    def compute_loss_of_similarity(self, inputs, mask=None):
        _, _, y_pred, _ = inputs
        y_true = self.get_labels_of_similarity(y_pred)  # build the labels
        y_pred = K.l2_normalize(y_pred, axis=1)  # normalize the sentence vectors
        similarities = K.dot(y_pred, K.transpose(y_pred))  # similarity matrix
        similarities = similarities - K.eye(K.shape(y_pred)[0]) * 1e12  # mask the diagonal
        similarities = similarities * 30  # scale
        loss = K.categorical_crossentropy(
            y_true, similarities, from_logits=True
        )
        return loss

    def get_labels_of_similarity(self, y_pred):
        # Rows 2i and 2i+1 are paraphrases of each other, so each row's
        # positive label is its neighbor within the pair.
        idxs = K.arange(0, K.shape(y_pred)[0])
        idxs_1 = idxs[None, :]
        idxs_2 = (idxs + 1 - idxs % 2 * 2)[:, None]
        labels = K.equal(idxs_1, idxs_2)
        labels = K.cast(labels, K.floatx())
        return labels


# Build and load the model
bert = build_transformer_model(
    config_path,
    checkpoint_path,
    with_pool='linear',
    application='unilm',
    keep_tokens=keep_tokens,  # keep only the tokens in keep_tokens, simplifying the vocabulary
    return_keras_model=False,
)

encoder = keras.models.Model(bert.model.inputs, bert.model.outputs[0])
seq2seq = keras.models.Model(bert.model.inputs, bert.model.outputs[1])

outputs = TotalLoss([2, 3])(bert.model.inputs + bert.model.outputs)
model = keras.models.Model(bert.model.inputs, outputs)

AdamW = extend_with_weight_decay(Adam, 'AdamW')
optimizer = AdamW(learning_rate=2e-6, weight_decay_rate=0.01)
model.compile(optimizer=optimizer)
model.summary()


class SynonymsGenerator(AutoRegressiveDecoder):
    """seq2seq decoder"""
    @AutoRegressiveDecoder.set_rtype('probas')
    def predict(self, inputs, output_ids, step):
        token_ids, segment_ids = inputs
        token_ids = np.concatenate([token_ids, output_ids], 1)
        segment_ids = np.concatenate([segment_ids, np.ones_like(output_ids)], 1)
        return seq2seq.predict([token_ids, segment_ids])[:, -1]

    def generate(self, text, n=1, topk=5):
        token_ids, segment_ids = tokenizer.encode(text, max_length=maxlen)
        output_ids = self.random_sample([token_ids, segment_ids], n,
                                        topk)  # random sampling
        return [tokenizer.decode(ids) for ids in output_ids]


synonyms_generator = SynonymsGenerator(
    start_id=None, end_id=tokenizer._token_end_id, maxlen=maxlen
)


def gen_synonyms(text, n=100, k=20):
    """Generate n paraphrases of text and return the k most similar ones.
    Method: generate with seq2seq, then score and rank with the encoder.
    Example:
    >>> gen_synonyms(u'微信和支付宝哪个好?')
    [
        u'微信和支付宝,哪个好?',
        u'微信和支付宝哪个好',
        u'支付宝和微信哪个好',
        u'支付宝和微信哪个好啊',
        u'微信和支付宝那个好用?',
        u'微信和支付宝哪个好用',
        u'支付宝和微信那个更好',
        u'支付宝和微信哪个好用',
        u'微信和支付宝用起来哪个好?',
        u'微信和支付宝选哪个好',
    ]
    """
    r = synonyms_generator.generate(text, n)
    r = [i for i in set(r) if i != text]
    r = [text] + r
    X, S = [], []
    for t in r:
        x, s = tokenizer.encode(t)
        X.append(x)
        S.append(s)
    X = sequence_padding(X)
    S = sequence_padding(S)
    Z = encoder.predict([X, S])
    Z /= (Z**2).sum(axis=1, keepdims=True)**0.5  # L2-normalize the vectors
    argsort = np.dot(Z[1:], -Z[0]).argsort()
    return [r[i + 1] for i in argsort[:k]]


def just_show():
    """Show the output for a few random samples."""
    some_samples = train_generator.some_samples
    S = [np.random.choice(some_samples) for i in range(3)]
    for s in S:
        try:
            print(u'原句子:%s' % s)  # original sentence
            print(u'同义句子:')  # paraphrases
            print(gen_synonyms(s, 10, 10))
            print()
        except:
            pass


class Evaluate(keras.callbacks.Callback):
    """Save the weights after every epoch and demo a few paraphrases."""
    def __init__(self):
        self.lowest = 1e10

    def on_epoch_end(self, epoch, logs=None):
        model.save_weights('./latest_model.weights')
        # keep the weights with the lowest loss
        if logs['loss'] <= self.lowest:
            self.lowest = logs['loss']
            model.save_weights('./best_model.weights')
        # demo the current model
        just_show()


if __name__ == '__main__':

    train_generator = data_generator(read_corpus("./data_sample.json"), batch_size)
    evaluator = Evaluate()

    model.fit_generator(
        train_generator.forfit(),
        steps_per_epoch=steps_per_epoch,
        epochs=epochs,
        callbacks=[evaluator]
    )

else:

    model.load_weights('./latest_model.weights')
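Once training has run, generating paraphrases only needs the script imported as a module; the else branch above then loads ./latest_model.weights automatically. A minimal usage sketch (the module name simbert_train is hypothetical):

# Hypothetical usage sketch: importing the training script builds the model
# and, outside __main__, loads ./latest_model.weights.
import simbert_train

# Generate 50 candidates and keep the 10 most similar ones.
print(simbert_train.gen_synonyms(u'微信和支付宝哪个好?', n=50, k=10))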

Training data format:

{"text": "有现金.住房公积金里有十万,可以买房吗", "synonyms": ["用住房公积金贷款买房,同时可以取出现金吗", "住房公积金买房时可以当现金首付吗", "如果现手上十五万现金,十万住房公积金,如何买房", "异地买房可以用住房公积金贷款吗?", "住房公积金贷款买房能现金提取吗?", "没现金 只有住房公积金怎么买房?"]}
{"text": "女方提出离婚吃亏在哪", "synonyms": ["女方提出离婚吃亏吗一", "女方提出离婚,我是不是是吃亏", "男方主动提出离婚吃亏还是女方主动提出离婚吃亏?", "女方先提出离婚会怎么样,先提出的吃亏", "谁先提出离婚谁吃亏吗", "女方主动提出离婚是否一定会吃亏?", "女方提出离婚要具备哪些条件", "女方向法院起诉离婚哪一方会吃亏"]}

Training process: (screenshot in the original post)

Similarity-model architecture

Milvus: stores the encoded sentence vectors

SimBERT: encodes sentences into vectors

POST endpoint: receives a sentence and returns the most similar ones

That's the basic architecture; I'll share the project when I get the chance.
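As a rough illustration of how the pieces fit together, here is a hedged sketch of the POST service. It reuses tokenizer, encoder, maxlen, and sequence_padding from the training script above, and a brute-force numpy dot product stands in for the Milvus vector search; the Flask route, corpus, and all other names are illustrative, not the author's actual project.

# Sketch of the similarity service: SimBERT encodes sentences into vectors,
# and a normalized dot product stands in for a Milvus top-k vector search.
import numpy as np
from flask import Flask, request, jsonify

app = Flask(__name__)

def encode(texts):
    """Encode sentences with the SimBERT encoder and L2-normalize them."""
    X, S = [], []
    for t in texts:
        x, s = tokenizer.encode(t, max_length=maxlen)
        X.append(x)
        S.append(s)
    Z = encoder.predict([sequence_padding(X), sequence_padding(S)])
    return Z / np.linalg.norm(Z, axis=1, keepdims=True)

corpus = [u'微信和支付宝哪个好', u'女方提出离婚吃亏在哪']  # illustrative index
corpus_vecs = encode(corpus)  # in the real system these vectors live in Milvus

@app.route('/similar', methods=['POST'])
def similar():
    q = encode([request.json['text']])[0]
    scores = corpus_vecs.dot(q)      # cosine similarity (vectors are normalized)
    top = np.argsort(-scores)[:10]   # this dot + argsort is what Milvus replaces at scale
    return jsonify([{'text': corpus[i], 'score': float(scores[i])} for i in top])

A client would then POST JSON like {"text": "..."} with Content-Type application/json and get back the top matches with their scores.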

