BertGCN的fastNLP实现

news/2024/11/14 14:22:18/

目的

本文主要介绍如何实现fastNLP 来复现今年发表在顶会的一篇论文BertGCN: Transductive Text Classification by Combining GCN and BERT。

FastNLP配置

本文采用的fastNLP版本号为0.6.0,可采用一下命令来安装

pip install -b dev https://github.com/fastnlp/fastNLP.git
python setup.py build
python setup.py install

数据预处理

论文采用的架构是bert和gcn,故数据集需要分两步来处理,第一步是将数据集处理成一张图,得到图的邻接矩阵,第二步将其处理成适应bert输入的序列形式。由于FastNLP封装了论文中采用的5个数据集的loader函数和PMIBuildGraph函数,故数据处理的代码变得很简单,具体如下:

class PrepareData:def __init__(self, args):self.arg = argsself.tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model)if self.arg.dataset == 'mr':data_bundle, adj, target_vocab = self._get_input(MRLoader, MRPmiGraphPipe,  args.dev_ratio)elif self.arg.dataset == 'R8':data_bundle, adj, target_vocab = self._get_input(R8Loader, R8PmiGraphPipe,  args.dev_ratio)elif self.arg.dataset == 'R52':data_bundle, adj, target_vocab = self._get_input(R52Loader, R52PmiGraphPipe,  args.dev_ratio)elif self.arg.dataset == 'ohsumed':data_bundle, adj, target_vocab = self._get_input(OhsumedLoader, OhsumedPmiGraphPipe,  args.dev_ratio)elif self.arg.dataset == '20ng':data_bundle, adj, target_vocab = self._get_input(NG20Loader, NG20PmiGraphPipe,  args.dev_ratio)else:raise RuntimeError('输入数据集错误,请更改为["mr", "R8", "R52", "ohsumed", "20ng"]')self.data_bundle = data_bundleself.target_vocab = target_vocab## 论文中的memory bank实现形式feats = th.FloatTensor(th.randn((adj.shape[0], args.embed_size)))self.graph_info = {"adj": adj, "feats": feats}def _get_input(self, loader:loader, buildGraph, dev_ratio=0.2):##加载数据集load, bg = loader(), buildGraph()data_bundle = load.load(load.download(dev_ratio=dev_ratio))adj, index = bg.build_graph(data_bundle)## 添加doc标签,以便于在图中定位文档的位置data_bundle.get_dataset('train').add_field('doc_id', index[0])data_bundle.get_dataset('dev').add_field('doc_id', index[1])data_bundle.get_dataset('test').add_field('doc_id', index[2])## 使用bert的分词器对数据文本进行分词data_bundle.get_dataset('train').apply_field(lambda x: self.tokenizer(x.replace('\\', ''), truncation=True, max_length=self.arg.max_len,padding='max_length').input_ids, 'raw_words', 'input_ids')data_bundle.get_dataset('dev').apply_field(lambda x: self.tokenizer(x.replace('\\', ''), truncation=True, max_length=self.arg.max_len,padding='max_length').input_ids, 'raw_words', 'input_ids')data_bundle.get_dataset('test').apply_field(lambda x: self.tokenizer(x.replace('\\', ''), truncation=True, max_length=self.arg.max_len,padding='max_length').input_ids, 'raw_words', 'input_ids')#将标签映射成数值target_vocab = Vocabulary(padding=None, unknown=None)target_vocab.from_dataset(data_bundle.get_dataset('train'), field_name='target',no_create_entry_dataset=[data_bundle.get_dataset('dev'), data_bundle.get_dataset('test')])target_vocab.index_dataset(data_bundle.get_dataset('train'),data_bundle.get_dataset('dev'),data_bundle.get_dataset('test'), field_name='target')#将其设置为模型的输入和真实输出data_bundle.get_dataset('train').set_input('doc_id', 'input_ids')data_bundle.get_dataset('dev').set_input('doc_id', 'input_ids')data_bundle.get_dataset('test').set_input('doc_id', 'input_ids')data_bundle.get_dataset('train').set_target('target')data_bundle.get_dataset('dev').set_target('target')data_bundle.get_dataset('test').set_target('target')##->>>>>>>>>>>>>>>>>>>>>>>>>>adj = adj + adj.T.multiply(adj.T > adj) - adj.multiply(adj.T > adj)adj = preprocess_adj(adj)return data_bundle, adj, target_vocab

可以看到其主要代码是将图中文档id与文本数据对应,使用bert的分词器来分词,将文档的标签映射成数值;大大减少了数据预处理的复杂度。

模型构建

根据论文的描述,我们将memory bank集成在模型中,在训练过程中能实时更新文档的feature向量。具体见如下代码:

class BertGCN(nn.Module):def __init__(self, pretrained_model='roberta_base', nb_class=20, gcn_hidden_size=256,m=0.3, dropout=0.5, graph_info=None):super(BertGCN, self).__init__()self.bert_model = BertModel.from_pretrained(pretrained_model)self.feat_dim = list(self.bert_model.modules())[-2].out_featuresself.clssifier = th.nn.Linear(self.feat_dim, nb_class)self.gcn = GCN(self.feat_dim, gcn_hidden_size, nb_class, dropout=dropout)self.graph_info = graph_infoself.m = mdef forward(self, input_ids, doc_id):attention_mask = input_ids > 0cls_feats = self.bert_model(input_ids, attention_mask)[0][:, 0]cls_pred = self.clssifier(cls_feats)self.graph_info['feats'][doc_id] = cls_feats.detach()gcn_pred = self.gcn(self.graph_info['feats'],self.graph_info['adj'])[doc_id]pred = gcn_pred*self.m + (1-self.m)*cls_predreturn {'pred': pred}

代码中graph_info就是论文中提出的memory bank,实现batch化更新图上的参数。参数m是论文提出gcn和bert的线性组合比例,其为超参数。

训练和测试

FastNLP提供了Trianer,Tester函数,以及常见callback,sheduler等函数,具体可以自己探索,可以避免自己重复造轮子。具体代码如下:

from fastNLP import Tester, Trainer, CrossEntropyLoss, AccuracyMetric,  EarlyStopCallback, LRScheduler
from data_loader import PrepareData
from model import BertGCN
import argparse
import torch as th
parse = argparse.ArgumentParser()parse.add_argument('--dataset', default="20ng", help="[mr, 20ng, R8, R52, ohsumed]")
parse.add_argument("--embed_size", default=768)
parse.add_argument("--gcn_hidden_size", default=256)
# parse.add_argument("--cls_type", default=2)
parse.add_argument("--devices_gpu", default=[0, 1, 2])
parse.add_argument("--lr", default=2e-5, help="learning rate")
parse.add_argument("--bert_lr", default=2e-5)
parse.add_argument("--gcn_lr", default=2e-3)
parse.add_argument("--batch_size", default=32)
parse.add_argument("--max_len", default=128)
parse.add_argument("--p", default=0.3)
parse.add_argument("--pretrained_model", default='bert-base-uncased')
parse.add_argument("--nb_epoch", default=10)
parse.add_argument("--dropout", default=0.5)
parse.add_argument("--dev_ratio", default=0.2)
arg = parse.parse_args()
device = th.device("cuda")
## PrePareData
print("Data Loading")
pd = PrepareData(arg)
pd.graph_info['feats'] = pd.graph_info['feats'].to(device)
pd.graph_info['adj'] = pd.graph_info['adj'].to(device)
arg.cls_type = len(pd.target_vocab)### Load Model
print("Load Model")
model = BertGCN(arg.pretrained_model, arg.cls_type, arg.gcn_hidden_size,arg.p, arg.dropout, pd.graph_info)optim = th.optim.Adam([{'params': model.gcn.parameters(), 'lr': arg.gcn_lr},{'params': model.bert_model.parameters(), 'lr': arg.bert_lr},{'params': model.clssifier.parameters(), 'lr': arg.bert_lr},], lr=arg.lr)scheduler = th.optim.lr_scheduler.MultiStepLR(optim, milestones=[30], gamma=0.1)
callback = [EarlyStopCallback(10), LRScheduler(scheduler)]trainer = Trainer(pd.data_bundle.get_dataset('train'), model, loss=CrossEntropyLoss(target='target'),optimizer=optim, n_epochs=arg.nb_epoch, device=device, callbacks=callback,batch_size=arg.batch_size,dev_data=pd.data_bundle.get_dataset('dev'), metrics=AccuracyMetric(target='target'))trainer.train()tester = Tester(pd.data_bundle.get_dataset('test'), model, metrics=AccuracyMetric(target='target'),device=device)tester.test()

到此论文的复现就成功,只需要一键python run就行了。

下面给出未调参的结果, 暂时只跑了BertGCN:

ModelR8R52ohsumed20ngmr
BertGCN(论文)98.196.672.889.386.0
BertGCN(复现)98.12796.300671.951586.630486.2127

总结

跟作者给出的code对比,可以明显的看出使用FastNLP的代码量更小,且复现的结果跟原文的结果很接近。

BertGCN的完整代码见
添加链接描述


http://www.ppmy.cn/news/139785.html

相关文章

迪文屏OS汇编代码开发-参数修改 保存 翻页(七)

; DWIN OS ;程序功能:上翻页,下翻页,参数修改,保存 ;软件环境: DWIN OS ASM Builder V1.5 ;硬件环境:DW K600平台 ;变量 ;用户数据区地址从0x0600 0000开始分配,目前定义的参数区为40个 最大处方数。 ;参…

安卓app+esp8266+51单片机+光敏电阻+lcd1602实现智能照明系统

本文是本人51单片机和物联网的期末课程设计,没学过打板焊接,只用面包板和公母线实现。 安卓和esp8266控灯主要参考Android Studio设计APP实现与51单片机通过WIFI模块(ESP8266-01S)通讯控制LED灯亮灭的设计源码【详解】_手机app通…

码农的自我修养 - ARM处理器天梯图

ARM芯片族 - 架构 - 内核 - 总线速度列表: ARM GROUP ARM architecture ARM core Bus Speed ARM1 ARMv1 ARM1 ARM2 ARMv2 ARM2 4 MIPS 8 MHz 0.33 DMIPS/MHz ARMv2a ARM250 7 MIPS 12 MHz ARM3 ARMv2a ARM3 12 MIPS 25 MHz 0…

ARM各内核系列整型运算能力对比---DMIPS / MHz

DMIPS:Dhrystone Million Instructions executed Per Second (百万条整数运算指令/秒),用于衡量CPU整数计算能力。 超标量处理器: 是指在一颗处理器内核中实现了指令级并行的一类并行运算。在这里就是 DMIPS/MHz 大于…

RISC-V与ARM

RISC-V与ARM RISC-V 架构RISC-V架构特点ARM 架构RISC-V 与 ARM 指令集架构 (ISA) 基本上是汇编级程序员,或编译器编写者可见的机器部分。 ISA 是软件与硬件相遇的地方。 ISA 定义了机器及其微架构本身可以理解的命令/指令,它还定义了如何存储、访问和实施…

STM32 WAVWM8978简介

​ WAV即WAVE文件,是最常用的数字化声音文件格式之一,其扩展名为“.wav”。符合RIFF(Resource Interchange File Format)文件规范,用于保存Windows平台的音频信息资源,被Windows平台及其应用程序所广泛支持。 WAV格式还支持MS ADP…

ARM在汽车电子电器架构的应用

整理自ARM中国FAE高级经理及技术专家丁先生在集微网的演讲,侵删。 该演讲涵盖了汽车电子电器架构的多个方面,整体包含的知识面非常广。整个演讲非常精彩,也是非常佩服丁先生在汽车电子电器架构及ARM在其中的应用的精彩阐述。 本人也是从事汽…

WAVWM8978

一、WAV WAV即WAVE文件,是最常用的数字化声音文件格式之一,扩展名为‘.wav’,用于保存Windows平台的音频信息资源。WAV格式还支持MS ADPCM、CCITT A LAW 等多种压缩运算法,支持多种音频数字,取样频率和声道。标准格式化的WAV文件和…