基于BERT的序列到序列(Seq2Seq)模型,生成文本摘要或标题

news/2025/3/30 7:21:58/

  1. 数据预处理

    • 使用DataGenerator类加载并预处理数据,处理变长序列的padding。
    • 输入为内容(content),目标为标题(title)。
  2. 模型构建

    • 基于BERT构建Seq2Seq模型,使用交叉熵损失。
    • 采用Beam Search进行生成,支持Top-K采样。
  3. 训练与评估

    • 使用Adam优化器进行训练。
    • 每个epoch结束时通过Evaluate回调生成示例标题以观察效果。
import numpy as np
import pandas as pd
from tqdm import tqdm
from bert4keras.bert import build_bert_model
from bert4keras.tokenizer import Tokenizer, load_vocab
from keras.layers import *
from keras.models import Model
from keras import backend as K
from bert4keras.snippets import parallel_apply
from keras.optimizers import Adam
import keras
import math
from sklearn.model_selection import train_test_split
from rouge import Rouge  # 需要安装rouge包# 配置参数
config_path = 'bert/chinese_L-12_H-768_A-12/bert_config.json'
checkpoint_path = 'bert/chinese_L-12_H-768_A-12/bert_model.ckpt'
dict_path = 'bert/chinese_L-12_H-768_A-12/vocab.txt'max_input_len = 256
max_output_len = 32
batch_size = 16
epochs = 10
beam_size = 3
learning_rate = 2e-5
val_split = 0.1# 数据预处理增强
class DataGenerator(keras.utils.Sequence):def __init__(self, data, batch_size=8, mode='train'):self.batch_size = batch_sizeself.mode = modeself.data = dataself.indices = np.arange(len(data))def __len__(self):return math.ceil(len(self.data) / self.batch_size)def __getitem__(self, index):batch_indices = self.indices[index*self.batch_size : (index+1)*self.batch_size]batch = self.data.iloc[batch_indices]return self._process_batch(batch)def on_epoch_end(self):if self.mode == 'train':np.random.shuffle(self.indices)def _process_batch(self, batch):batch_x, batch_y = [], []for _, row in batch.iterrows():content = row['content'][:max_input_len]title = row['title'][:max_output_len-2]  # 保留空间给[CLS]和[SEP]# 编码器输入x, _ = tokenizer.encode(content, max_length=max_input_len)# 解码器输入输出y, _ = tokenizer.encode(title, max_length=max_output_len)decoder_input = [tokenizer._token_start_id] + y[:-1]decoder_output = ybatch_x.append(x)batch_y.append({'decoder_input': decoder_input, 'decoder_output': decoder_output})# 动态paddingpadded_x = self._pad_sequences([x for x in batch_x], maxlen=max_input_len)padded_decoder_input = self._pad_sequences([y['decoder_input'] for y in batch_y], maxlen=max_output_len,padding='post')padded_decoder_output = self._pad_sequences([y['decoder_output'] for y in batch_y],maxlen=max_output_len,padding='post')return [padded_x, padded_decoder_input], padded_decoder_outputdef _pad_sequences(self, sequences, maxlen, padding='pre'):padded = np.zeros((len(sequences), maxlen))for i, seq in enumerate(sequences):if len(seq) > maxlen:seq = seq[:maxlen]if padding == 'pre':padded[i, -len(seq):] = seqelse:padded[i, :len(seq)] = seqreturn padded# 改进的模型架构
def build_seq2seq_model():# 编码器encoder_inputs = Input(shape=(None,), name='Encoder-Input')encoder = build_bert_model(config_path=config_path,checkpoint_path=checkpoint_path,model='encoder',return_keras_model=False,)encoder_outputs = encoder(encoder_inputs)# 解码器decoder_inputs = Input(shape=(None,), name='Decoder-Input')decoder = build_bert_model(config_path=config_path,checkpoint_path=checkpoint_path,model='decoder',application='lm',return_keras_model=False,)decoder_outputs = decoder([decoder_inputs, encoder_outputs])# 连接模型model = Model([encoder_inputs, decoder_inputs], decoder_outputs)# 自定义损失函数(忽略padding)def seq2seq_loss(y_true, y_pred):y_mask = K.cast(K.not_equal(y_true, 0), K.floatx())loss = K.sparse_categorical_crossentropy(y_true, y_pred, from_logits=True)return K.sum(loss * y_mask) / K.sum(y_mask)model.compile(Adam(learning_rate), loss=seq2seq_loss)return model# 改进的Beam Search
def beam_search(model, input_seq, beam_size=3):encoder_input = tokenizer.encode(input_seq)[0]encoder_output = model.get_layer('bert').predict(np.array([encoder_input]))sequences = [[[tokenizer._token_start_id], 0.0]]for _ in range(max_output_len):all_candidates = []for seq, score in sequences:if seq[-1] == tokenizer._token_end_id:all_candidates.append((seq, score))continuedecoder_input = np.array([seq])decoder_output = model.get_layer('bert_1').predict([decoder_input, encoder_output])[:, -1, :]top_k = np.argsort(decoder_output[0])[-beam_size:]for token in top_k:new_seq = seq + [token]new_score = score + np.log(decoder_output[0][token])all_candidates.append((new_seq, new_score))# 长度归一化ordered = sorted(all_candidates, key=lambda x: x[1]/(len(x[0])+1e-8), reverse=True)sequences = ordered[:beam_size]best_seq = sequences[0][0]return tokenizer.decode(best_seq[1:-1])  # 去除[CLS]和[SEP]# 增强的评估回调
class AdvancedEvaluate(keras.callbacks.Callback):def __init__(self, val_data, sample_size=5):self.val_data = val_dataself.rouge = Rouge()self.samples = val_data.sample(sample_size)def on_epoch_end(self, epoch, logs=None):# 生成示例print("\n生成示例:")for _, row in self.samples.iterrows():generated = beam_search(self.model, row['content'], beam_size)print(f"真实标题: {row['title']}")print(f"生成标题: {generated}\n")# 计算ROUGE分数references = []hypotheses = []for _, row in self.val_data.iterrows():generated = beam_search(self.model, row['content'], beam_size=1)references.append(row['title'])hypotheses.append(generated)scores = self.rouge.get_scores(hypotheses, references, avg=True)print(f"验证集ROUGE-L: {scores['rouge-l']['f']:.4f}")# 主流程
if __name__ == "__main__":# 加载数据full_data = pd.read_csv('train.tsv', sep='\t', names=['title', 'content'])train_data, val_data = train_test_split(full_data, test_size=val_split)# 初始化tokenizertokenizer = Tokenizer(dict_path, do_lower_case=True)# 构建模型model = build_seq2seq_model()model.summary()# 数据生成器train_gen = DataGenerator(train_data, batch_size, mode='train')val_gen = DataGenerator(val_data, batch_size, mode='val')# 训练配置callbacks = [AdvancedEvaluate(val_data),keras.callbacks.ReduceLROnPlateau(factor=0.5, patience=2, verbose=1),keras.callbacks.ModelCheckpoint('best_model.h5', save_best_only=True)]# 开始训练model.fit(train_gen,validation_data=val_gen,epochs=epochs,callbacks=callbacks,workers=4,use_multiprocessing=True)


http://www.ppmy.cn/news/1583513.html

相关文章

【蓝桥杯每日一题】3.25

🏝️专栏: 【蓝桥杯备篇】 🌅主页: f狐o狸x “OJ超时不是终点,是算法在提醒你该优化时间复杂度了!” 目录 3.25 差分数组 一、一维差分 题目链接: 题目描述: 解题思路:…

Unity Shader编程】之复杂光照

在Unity Shader的LightMode标签中,除了前向渲染和延迟渲染外,还支持多种渲染模式设置。以下是主要分类及用途: 一、核心渲染路径模式 前向渲染相关 ForwardBase 用于基础光照计算,处理环境光、主平行光、逐顶点/SH光源及光照贴图。…

Windows命令提示符(CMD) 中切换目录主要通过 cd(Change Directory)命令实现

在 Windows命令提示符(CMD) 中切换目录主要通过 cd(Change Directory)命令实现。以下是详细的切换目录方式及常见用法示例: 使用技巧: 1.在文件夹的地址栏,直接输出cmd 即可跳转到当前的文档。…

系统分析师常考题目《论面向对象分析方法及其应用》

软考系统分析师论文范文系列 【摘要】 2022年6月,我所在团队承接了某金融机构“智能信贷风控系统”的设计与开发任务,我作为系统分析师主导了需求分析与架构设计工作。该系统旨在通过数据驱动的动态风控模型,优化信贷审批流程,提…

从零开始的大模型强化学习框架verl解析

之前在职的时候给一些算法的同学讲解过verl的框架设计、实现细节以及超参配置,写这篇文章姑且作为离职修养这段时期的复健。 本文中提到的做法和思路可能随着时间推移有变化,或者是思想迪化,仅代表个人理解。如果有错漏的地方还请指出。 现…

C#基础学习(六)函数的变长参数和参数默认值

什么是变长参数呢? 指的是你传入函数中的形参可以不定项性,你可以输入一个数组进去,就相当于有数组长度那么多的参数可以拿来使用。那么需要怎么来实现呢,就一个关键字params,这个关键字的作用就是当你写在函数参数传入的地方&…

相生、相克、乘侮、复杂病机及对应的脏腑功能联系

一、五行相生关系(母子关系) 五行生序脏腑关系生理表现举例木生火肝(木)滋养心(火)肝血充足则心血旺盛火生土心(火)温煦脾(土)心阳充足则脾胃运化功能正常土…

问题:md文档转换word,html,图片,excel,csv

文章目录 问题:md文档转换word,html,图片,excel,csv,ppt**主要职责****技能要求****发展方向****学习建议****薪资水平** 方案一:AI Markdown内容转换工具打开网站md文档转换wordmd文档转换pdfm…