Robert+Prompt+对比学习+对抗训练文本分类

news/2025/2/13 2:20:39/

基于Robert的文本分类任务,在此基础上考虑融合对比学习、Prompt和对抗训练来提升模型的文本分类能力,我本地有SST-2数据集的train.txt、dev.txt两个文件,每个文件包含文本内容和标签两列,是个二分类任务,本项目基于pytorch实现。

先介绍一下要融合的三个技术。

1. 对比学习旨在通过对比相似和不相似的样本来提高分类模型的性能。对于每个样本,我们可以在训练时随机选取一个与其相似的样本,并加入到训练中,以鼓励模型更好地学习相似样本的特征,同时在训练时也要随机选取一个不相似的样本,并将其加入到训练中。这可以帮助模型更好地区分不同类别之间的特征。

2. Prompt是一种基于预设文本片段的模型输入方式。通过给定关键词和语法结构,Prompt可以引导模型学习某些具体任务。在文本分类任务中,我们可以给模型预设一些文本提示,以帮助模型更好地学习关键特征。

3. 对抗训练是一种在训练模型时加入干扰数据(扰动)的技术,以增强模型的鲁棒性。在文本分类任务中,我们可以通过向文本中添加词语或修改词语顺序,来生成干扰数据,从而帮助模型更好地区分和理解输入文本。

目录

一、安装依赖库

二、载数据集并进行数据预处理

三、定义模型并训练模型

四、对比学习实现

五、Prompt实现

六、对抗训练实现

七、整个过程封装成一个函数


一、安装依赖库

下面是具体实现的代码,我们将使用PyTorch框架:

首先安装必要的库:

!pip install transformers
!pip install torch
!pip install scikit-learn

然后我们导入需要的库以及设置随机种子以保证实验可重复性等必要组件: 

import random
import numpy as np
import torch
from sklearn.metrics import accuracy_score, f1_score
from transformers import RobertaTokenizer, RobertaForSequenceClassification, AdamW
from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler
from transformers import get_linear_schedule_with_warmupdevice = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
torch.cuda.manual_seed_all(42)
torch.backends.cudnn.deterministic = True

二、载数据集并进行数据预处理

class TextDataset(Dataset):def __init__(self, tokenizer, path, max_length):self.tokenizer = tokenizerself.max_length = max_lengthself.labels = []self.texts = []with open(path) as f:for line in f:line = line.strip().split('\t')text, label = line[0], int(line[1])self.labels.append(label)self.texts.append(text)def __len__(self):return len(self.labels)def __getitem__(self, idx):text, label = self.texts[idx], self.labels[idx]encoding = self.tokenizer(text, truncation=True, padding='max_length', max_length=self.max_length,return_tensors='pt')return dict(text=text,input_ids=encoding['input_ids'].squeeze(),attention_mask=encoding['attention_mask'].squeeze(),labels=torch.tensor(label))tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
train_dataset = TextDataset(tokenizer, 'train.txt', 256)
dev_dataset = TextDataset(tokenizer, 'dev.txt', 256)train_sampler = RandomSampler(train_dataset)
dev_sampler = SequentialSampler(dev_dataset)train_loader = DataLoader(train_dataset, sampler=train_sampler, batch_size=16)
dev_loader = DataLoader(dev_dataset, sampler=dev_sampler, batch_size=16)

三、定义模型并训练模型

model = RobertaForSequenceClassification.from_pretrained('roberta-base', num_labels=2).to(device)
optimizer = AdamW(model.parameters(), lr=2e-5, weight_decay=0.01)# We will use a linear decay scheduler
total_steps = len(train_loader) * 5
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=total_steps)for epoch in range(5):model.train()for batch in train_loader:batch = {k: v.to(device) for k, v in batch.items()}optimizer.zero_grad()outputs = model(**batch)loss = outputs[0]loss.backward()optimizer.step()scheduler.step()model.eval()with torch.no_grad():targets, preds = [], []for batch in dev_loader:batch = {k: v.to(device) for k, v in batch.items()}outputs = model(**batch)targets.extend(batch['labels'].tolist())preds.extend(torch.argmax(outputs.logits, axis=-1).tolist())acc = accuracy_score(targets, preds)f1 = f1_score(targets, preds)print(f'\nEpoch {epoch + 1}:')print(f'Dev Accuracy: {acc:.4f}')print(f'Dev F1 Score: {f1:.4f}')

至此,我们已经成功地训练了一款基于RoBERTa模型的文本分类器。下面是加入融合技术的实现。

四、对比学习实现

def random_similar_text(texts, labels):res_texts, res_labels = [], []for idx, text in enumerate(texts):res_texts.append(text)res_labels.append(labels[idx])# 随机选择一个与当前样本相似的样本,将它加入到数据集中rand_idx = np.random.choice(len(texts), 1)[0]res_texts.append(texts[rand_idx])res_labels.append(labels[rand_idx])# 随机选择一个不相似的样本,将它加入到数据集中rand_idx = np.random.choice(len(texts), 1)[0]while rand_idx == idx:rand_idx = np.random.choice(len(texts), 1)[0]res_texts.append(texts[rand_idx])res_labels.append(labels[rand_idx])return res_texts, res_labelstrain_texts, train_labels = random_similar_text(train_dataset.texts, train_dataset.labels)
train_dataset = TextDataset(tokenizer, 'train.txt', 256)

五、Prompt实现

def add_prompt(prompt, texts):return [f'{prompt}{text}' for text in texts]train_dataset.texts = add_prompt('This text is', train_dataset.texts)
dev_dataset.texts = add_prompt('This text is', dev_dataset.texts)

六、对抗训练实现

def add_perturbations(text, n):# 随机选择n个词,并在其周围添加一些噪声生成n个干扰文本words = text.split()idx_list = np.random.choice(len(words), n, replace=False)for idx in idx_list:words[idx] = f'[{words[idx]}]'return ' '.join(words)def generate_perturbations(texts):return [add_perturbations(text, 3) for text in texts]train_dataset.texts += generate_perturbations(train_dataset.texts)
dev_dataset.texts += generate_perturbations(dev_dataset.texts)

七、整个过程封装成一个函数

def train_roberta_with_fusion(train_path, dev_path, num_classes, fusion_type):def random_similar_text(texts, labels):res_texts, res_labels = [], []for idx, text in enumerate(texts):res_texts.append(text)res_labels.append(labels[idx])rand_idx = np.random.choice(len(texts), 1)[0]res_texts.append(texts[rand_idx])res_labels.append(labels[rand_idx])rand_idx = np.random.choice(len(texts), 1)[0]while rand_idx == idx:rand_idx = np.random.choice(len(texts), 1)[0]res_texts.append(texts[rand_idx])res_labels.append(labels[rand_idx])return res_texts, res_labelsdef add_perturbations(text, n):words = text.split()idx_list = np.random.choice(len(words), n, replace=False)for idx in idx_list:words[idx] = f'[{words[idx]}]'return ' '.join(words)def add_prompt(prompt, texts):return [f'{prompt}{text}' for text in texts]def generate_perturbations(texts):return [add_perturbations(text, 3) for text in texts]class TextDataset(Dataset):def __init__(self, tokenizer, path, max_length):self.tokenizer = tokenizerself.max_length = max_lengthself.labels = []self.texts = []with open(path) as f:for line in f:line = line.strip().split('\t')text, label = line[0], int(line[1])self.labels.append(label)self.texts.append(text)if fusion_type == 'contrastive':self.texts, self.labels = random_similar_text(self.texts, self.labels)if fusion_type == 'adversarial':self.texts += generate_perturbations(self.texts)if fusion_type == 'prompt':self.texts = add_prompt('This text is', self.texts)def __len__(self):return len(self.labels)def __getitem__(self, idx):text, label = self.texts[idx], self.labels[idx]encoding = self.tokenizer(text, truncation=True, padding='max_length', max_length=self.max_length,return_tensors='pt')return dict(text=text,input_ids=encoding['input_ids'].squeeze(),attention_mask=encoding['attention_mask'].squeeze(),labels=torch.tensor(label))device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')tokenizer = RobertaTokenizer.from_pretrained('roberta-base')train_dataset = TextDataset(tokenizer, train_path, 256)dev_dataset = TextDataset(tokenizer, dev_path, 256)train_sampler = RandomSampler(train_dataset)dev_sampler = SequentialSampler(dev_dataset)train_loader = DataLoader(train_dataset, sampler=train


http://www.ppmy.cn/news/350943.html

相关文章

Springboot+Vue服务器盲盒活动

文章目录 一、项目要求二、说明文档1、用户抽奖主页/raffle2、多种奖品链接1、奖品1 discont /discount2、奖品2 CPU upgrade /cpu3、奖品3 Memory upgrade /memory4、奖品4 Increase duration /duration5、奖品5 Send to server /server6、奖品6 Configuration upgrade /upgra…

【C#】并行编程实战:序章

前言 本文主要是基于这本书学习的: 《并行编程实战:基于C#8和.NET Core 3》,我当时买的实体书,长下面这个样子。我买了大概浏览了一下,感觉内容还行(基本都是没见过的新东西),所以打…

英语学习:P开头

pace 步子,节奏 pack 包 package 一包 packet 小包装,袋 paddle 浆状物 pain 疼痛 painful 使痛苦的 painter 绘画者 painting 油画 pair 一双 palace 宫殿 pale 苍白的 pan 平底锅 pancake 薄煎饼 panda 熊猫 panic 惊慌 paper 纸 pap…

JHU ICBM DTI Atlas纤维束模板介绍

JHU ICBM DTI Atlas(Johns Hopkins University Diffusion Tensor Imaging Atlas)是一种常用的脑白质纤维束解剖模板。它基于扩散张量成像(DTI)技术,提供了对人类脑部主要纤维束的定位和可视化。该纤维束解剖模板是通过对大量健康志愿者的脑部DTI图像进行分析和统计得出的。…

android浏览器对比评测,Android浏览器对比测试:QQ浏览器大幅领先

下载速度:QQ浏览器比UC浏览器快 接下来,再来测试下载速度。为了保证下载速度的公正性,我们选择从手机新浪网下载同一个大小为4354K的天气应用。UC浏览器耗时1分11秒,QQ浏览器耗时56.2秒。 下载页面截屏:左为UC浏览器&a…

计算机配件对比,电路板对比_手机配件评测_太平洋电脑网PConline

拆解详情 原装充电头内部构造 高压区电容配置 初级侧原件分布 旁边的塑料盖 将所有的高压原件固定在内部 三星原装的充电头产品,外壳较厚,内部固定较为稳固,要拆开,也费了不少时间,而拆开之后,可以看到&…

miui系统android os,color os对比miui 一加手机刷Color OS与MIUI系统体验对比评测

color os对比miui,下面脚本之家小编将为大家带来一加手机刷Color OS与MIUI系统体验对比评测,感兴趣的朋友可以过来看一看! 一加创始人刘作虎曾在一加发布会上说,一加是一个开放性的平台,在未来将会加入对于MIUI&#x…

realme x2 深度测试打不开_realme X2 Pro手机使用深度对比实用评测

realme X2 Pro怎么样?很多小伙伴们都还不知道,下面小编为大家整理了realme X2 Pro全面评测,一起来看看吧。 realme X2 Pro怎么样 realme X2 Pro将采用双立体声扬声器,支持杜比全景声和Hi-Res认证等。 硬件方面,realme …