Purpose
This post explains how to use fastNLP to reproduce BertGCN: Transductive Text Classification by Combining GCN and BERT, a paper published at a top venue this year.
FastNLP Setup
This post uses fastNLP 0.6.0, which can be installed from source with the following commands:
git clone -b dev https://github.com/fastnlp/fastNLP.git
cd fastNLP
python setup.py build
python setup.py install
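To check that the right build is picked up, a quick look at the version in Python is enough (this assumes fastNLP exposes __version__; hedged, not verified against every release):

import fastNLP
print(fastNLP.__version__)  # should print 0.6.0 if the install above succeeded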
Data Preprocessing
The paper's architecture combines BERT with a GCN, so the data is processed in two steps: first, the corpus is turned into a single graph and its adjacency matrix is computed; second, the documents are converted into the token-id sequences that BERT expects. Since fastNLP already wraps loaders for the five datasets used in the paper as well as the corresponding PMI graph-building pipes, the preprocessing code stays short:
class PrepareData:
    def __init__(self, args):
        self.arg = args
        self.tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model)
        if self.arg.dataset == 'mr':
            data_bundle, adj, target_vocab = self._get_input(MRLoader, MRPmiGraphPipe, args.dev_ratio)
        elif self.arg.dataset == 'R8':
            data_bundle, adj, target_vocab = self._get_input(R8Loader, R8PmiGraphPipe, args.dev_ratio)
        elif self.arg.dataset == 'R52':
            data_bundle, adj, target_vocab = self._get_input(R52Loader, R52PmiGraphPipe, args.dev_ratio)
        elif self.arg.dataset == 'ohsumed':
            data_bundle, adj, target_vocab = self._get_input(OhsumedLoader, OhsumedPmiGraphPipe, args.dev_ratio)
        elif self.arg.dataset == '20ng':
            data_bundle, adj, target_vocab = self._get_input(NG20Loader, NG20PmiGraphPipe, args.dev_ratio)
        else:
            raise RuntimeError('Unknown dataset, please choose one of ["mr", "R8", "R52", "ohsumed", "20ng"]')
        self.data_bundle = data_bundle
        self.target_vocab = target_vocab
        # the memory bank from the paper: one randomly initialized feature vector per graph node
        feats = th.FloatTensor(th.randn((adj.shape[0], args.embed_size)))
        self.graph_info = {"adj": adj, "feats": feats}

    def _get_input(self, loader, buildGraph, dev_ratio=0.2):
        # load the dataset and build the word/document PMI graph
        load, bg = loader(), buildGraph()
        data_bundle = load.load(load.download(dev_ratio=dev_ratio))
        adj, index = bg.build_graph(data_bundle)
        # add a doc_id field so every document can be located in the graph
        data_bundle.get_dataset('train').add_field('doc_id', index[0])
        data_bundle.get_dataset('dev').add_field('doc_id', index[1])
        data_bundle.get_dataset('test').add_field('doc_id', index[2])
        # tokenize the raw text with the BERT tokenizer
        for name in ('train', 'dev', 'test'):
            data_bundle.get_dataset(name).apply_field(
                lambda x: self.tokenizer(x.replace('\\', ''), truncation=True, max_length=self.arg.max_len,
                                         padding='max_length').input_ids,
                'raw_words', 'input_ids')
        # map the labels to indices
        target_vocab = Vocabulary(padding=None, unknown=None)
        target_vocab.from_dataset(data_bundle.get_dataset('train'), field_name='target',
                                  no_create_entry_dataset=[data_bundle.get_dataset('dev'),
                                                           data_bundle.get_dataset('test')])
        target_vocab.index_dataset(data_bundle.get_dataset('train'), data_bundle.get_dataset('dev'),
                                   data_bundle.get_dataset('test'), field_name='target')
        # mark the model inputs and the gold targets
        for name in ('train', 'dev', 'test'):
            data_bundle.get_dataset(name).set_input('doc_id', 'input_ids')
            data_bundle.get_dataset(name).set_target('target')
        # symmetrize and normalize the adjacency matrix
        adj = adj + adj.T.multiply(adj.T > adj) - adj.multiply(adj.T > adj)
        adj = preprocess_adj(adj)
        return data_bundle, adj, target_vocab
The main work here is mapping each document to its node id in the graph, tokenizing the text with the BERT tokenizer, and converting labels to indices; the loaders and graph pipes take care of the rest, which greatly reduces the preprocessing effort.
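As a quick sanity check, you can instantiate PrepareData and inspect what it produces. This is only a usage sketch, not part of the original script; args stands for the parsed command-line arguments shown in the training script later on:

pd = PrepareData(args)
print(pd.data_bundle)                 # fastNLP DataBundle: the three splits and the fields added above
print(pd.target_vocab)                # label-to-index vocabulary
print(pd.graph_info['feats'].shape)   # memory bank: (number of graph nodes, embed_size)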
Model Construction
Following the paper, the memory bank is integrated into the model so that the document feature vectors are updated on the fly during training. See the code below:
class BertGCN(nn.Module):
    def __init__(self, pretrained_model='roberta_base', nb_class=20, gcn_hidden_size=256,
                 m=0.3, dropout=0.5, graph_info=None):
        super(BertGCN, self).__init__()
        self.bert_model = BertModel.from_pretrained(pretrained_model)
        # hidden size of the encoder, read from the out_features of its last linear layer
        self.feat_dim = list(self.bert_model.modules())[-2].out_features
        self.clssifier = th.nn.Linear(self.feat_dim, nb_class)
        self.gcn = GCN(self.feat_dim, gcn_hidden_size, nb_class, dropout=dropout)
        self.graph_info = graph_info  # memory bank: full-graph adjacency and node features
        self.m = m                    # interpolation weight between the GCN and BERT predictions

    def forward(self, input_ids, doc_id):
        attention_mask = input_ids > 0
        # [CLS] representation of each document in the batch
        cls_feats = self.bert_model(input_ids, attention_mask)[0][:, 0]
        cls_pred = self.clssifier(cls_feats)
        # write the fresh document features back into the memory bank
        self.graph_info['feats'][doc_id] = cls_feats.detach()
        # run the GCN over the whole graph and keep the rows of the current batch
        gcn_pred = self.gcn(self.graph_info['feats'], self.graph_info['adj'])[doc_id]
        # linear interpolation of the two predictions
        pred = gcn_pred * self.m + (1 - self.m) * cls_pred
        return {'pred': pred}
In this code, graph_info is the memory bank proposed in the paper; it lets the node features of the graph be updated batch by batch. The parameter m is the hyperparameter from the paper that controls the linear interpolation between the GCN and BERT predictions.
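To make the memory-bank idea concrete, here is a minimal, self-contained sketch with made-up shapes and values; it only mimics the tensor operations in forward and is not the actual model:

import torch as th

n_nodes, feat_dim, n_class = 100, 768, 20   # toy sizes: graph nodes, BERT hidden size, classes
feats = th.randn(n_nodes, feat_dim)         # memory bank: one row per graph node (documents + words)

doc_id = th.tensor([3, 17, 42])             # graph positions of the documents in the current batch
cls_feats = th.randn(3, feat_dim)           # fresh [CLS] features produced by BERT for this batch
feats[doc_id] = cls_feats.detach()          # batch-wise update of the memory bank

z_bert = th.randn(3, n_class)               # BERT classifier output for the batch
z_gcn = th.randn(n_nodes, n_class)[doc_id]  # GCN output over the full graph, sliced to the batch
m = 0.3                                     # interpolation weight (the hyperparameter m above)
pred = m * z_gcn + (1 - m) * z_bert         # BertGCN's final prediction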
Training and Testing
fastNLP provides Trainer and Tester classes, along with common callbacks, schedulers and so on (worth exploring on your own), which saves you from reinventing the wheel. The training script is as follows:
from fastNLP import Tester, Trainer, CrossEntropyLoss, AccuracyMetric, EarlyStopCallback, LRScheduler
from data_loader import PrepareData
from model import BertGCN
import argparse
import torch as th
parse = argparse.ArgumentParser()
parse.add_argument('--dataset', default="20ng", help="one of [mr, 20ng, R8, R52, ohsumed]")
parse.add_argument("--embed_size", default=768)
parse.add_argument("--gcn_hidden_size", default=256)
# parse.add_argument("--cls_type", default=2)
parse.add_argument("--devices_gpu", default=[0, 1, 2])
parse.add_argument("--lr", default=2e-5, help="learning rate")
parse.add_argument("--bert_lr", default=2e-5)
parse.add_argument("--gcn_lr", default=2e-3)
parse.add_argument("--batch_size", default=32)
parse.add_argument("--max_len", default=128)
parse.add_argument("--p", default=0.3)
parse.add_argument("--pretrained_model", default='bert-base-uncased')
parse.add_argument("--nb_epoch", default=10)
parse.add_argument("--dropout", default=0.5)
parse.add_argument("--dev_ratio", default=0.2)
arg = parse.parse_args()
device = th.device("cuda")
## prepare the data
print("Data Loading")
pd = PrepareData(arg)
pd.graph_info['feats'] = pd.graph_info['feats'].to(device)
pd.graph_info['adj'] = pd.graph_info['adj'].to(device)
arg.cls_type = len(pd.target_vocab)  # the number of classes comes from the label vocabulary
### Load Model
print("Load Model")
model = BertGCN(arg.pretrained_model, arg.cls_type, arg.gcn_hidden_size, arg.p, arg.dropout, pd.graph_info)
# separate learning rates for the GCN and the BERT parts, as in the paper
optim = th.optim.Adam([
    {'params': model.gcn.parameters(), 'lr': arg.gcn_lr},
    {'params': model.bert_model.parameters(), 'lr': arg.bert_lr},
    {'params': model.clssifier.parameters(), 'lr': arg.bert_lr},
], lr=arg.lr)
scheduler = th.optim.lr_scheduler.MultiStepLR(optim, milestones=[30], gamma=0.1)
callback = [EarlyStopCallback(10), LRScheduler(scheduler)]
trainer = Trainer(pd.data_bundle.get_dataset('train'), model, loss=CrossEntropyLoss(target='target'),
                  optimizer=optim, n_epochs=arg.nb_epoch, device=device, callbacks=callback,
                  batch_size=arg.batch_size, dev_data=pd.data_bundle.get_dataset('dev'),
                  metrics=AccuracyMetric(target='target'))
trainer.train()
tester = Tester(pd.data_bundle.get_dataset('test'), model, metrics=AccuracyMetric(target='target'),
                device=device)
tester.test()
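If you want to keep the fine-tuned weights for later use, a plain torch checkpoint after testing is enough (the file name below is only an example):

th.save(model.state_dict(), 'bertgcn_{}.pt'.format(arg.dataset))  # saves the BERT, classifier and GCN weights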
That completes the reproduction; all that is left is to run the script with python.
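For example, a run on R8 might look like this (assuming the script above is saved as train.py; the file name is illustrative):

python train.py --dataset R8 --batch_size 64 --nb_epoch 10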
Below are the results without any hyperparameter tuning (test accuracy in %); only BertGCN has been run so far:
| Model | R8 | R52 | ohsumed | 20ng | mr |
| --- | --- | --- | --- | --- | --- |
| BertGCN (paper) | 98.1 | 96.6 | 72.8 | 89.3 | 86.0 |
| BertGCN (reproduced) | 98.127 | 96.3006 | 71.9515 | 86.6304 | 86.2127 |
Summary
Compared with the authors' released code, the fastNLP implementation clearly requires less code, and the reproduced results are close to those reported in the paper.
The complete BertGCN code is available here: [link]