对偶对比学习方法在文本分类任务中的应用

news/2025/3/29 19:36:58/

对偶对比学习(Dual Contrastive Learning,DCL)是一种新兴的自监督学习方法,它可以用于学习文本的表示。与传统的对比学习方法不同,DCL使用对偶性原理,将正样本和负样本的对比学习转化为两个对称的任务,从而提高了模型的性能。

在文本分类任务中,DCL可以用于学习文本的表示,从而提高分类的准确性。具体来说,DCL使用两个对称的任务来学习文本的表示:正样本任务和负样本任务。在正样本任务中,DCL将同一篇文本的不同片段作为正样本,将其他文本的任意片段作为负样本,从而学习文本的表示。在负样本任务中,DCL将同一篇文本的任意两个片段作为负样本,将其他文本的任意片段作为正样本,从而学习文本的表示。

通过这种对称的方式,DCL可以有效地学习文本的表示,并提高文本分类的准确性。实验结果表明,DCL在多个文本分类任务中都取得了优秀的性能,比传统的对比学习方法和其他自监督学习方法都要好。因此,DCL是一种非常有潜力的自监督学习方法,可以用于学习文本的表示和其他自然语言处理任务。

目录

一、任务描述

二、代码详解

三、改进思想

3.1  添加随机掩码

3.2 添加同义词替换

四、DCL和SimCLR对比


一、任务描述

基于Robert的文本分类任务,在此基础上融合对偶对比学习(Dual Contrastive Learning,DCL)和对抗训练来提升模型的文本分类能力,我本地有SST-2数据集的train.txt、test.txt、dev.txt三个文件,每个文件包含文本内容和标签两列,用pytorch实现任务。

二、代码详解

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from transformers import RobertaTokenizer, RobertaModel# 定义模型
class DCL_Roberta(nn.Module):def __init__(self, roberta_model):super(DCL_Roberta, self).__init__()self.roberta = roberta_modelself.hidden_size = roberta_model.config.hidden_size  # 获取模型的隐藏层维度self.dropout = nn.Dropout(0.1)  # 定义dropout层,用于防止过拟合self.classifier = nn.Linear(self.hidden_size, 2)  # 定义分类器,将隐藏层的表示映射到2个类别上def forward(self, input_ids, attention_mask):outputs = self.roberta(input_ids=input_ids, attention_mask=attention_mask)  # 使用RoBERTa模型获取文本的表示last_hidden_state = outputs.last_hidden_state  # 获取RoBERTa模型最后一层的隐藏层表示# 对偶对比学习pos_hidden = self.dropout(last_hidden_state[:, 0, :].unsqueeze(1).repeat(1, last_hidden_state.shape[1], 1))  # 获取正样本的表示,使用dropout层进行正则化neg_hidden = self.dropout(last_hidden_state.repeat(1, 2, 1).view(last_hidden_state.shape[0], -1, self.hidden_size))  # 获取负样本的表示,使用dropout层进行正则化logits = self.classifier(pos_hidden - neg_hidden)  # 计算正负样本的差异,并通过分类器映射到2个类别上return logits# 定义数据集
class SST2Dataset(Dataset):def __init__(self, file_path, tokenizer):self.tokenizer = tokenizerself.sentences, self.labels = self.load_data(file_path)def load_data(self, file_path):sentences = []labels = []with open(file_path, "r", encoding="utf-8") as f:for line in f:sentence, label = line.strip().split("\t")  # 读取每行文本和标签sentences.append(sentence)labels.append(int(label))return sentences, labelsdef __len__(self):return len(self.sentences)def __getitem__(self, idx):sentence = self.sentences[idx]label = self.labels[idx]inputs = self.tokenizer.encode_plus(sentence, add_special_tokens=True, return_tensors="pt")  # 使用tokenizer对文本进行编码input_ids = inputs["input_ids"].squeeze()  # 获取文本的token id,并去除多余的维度attention_mask = inputs["attention_mask"].squeeze()  # 获取文本的attention mask,并去除多余的维度return input_ids, attention_mask, label# 定义训练函数
def train(model, train_loader, optimizer, criterion, device):model.train()  # 设置模型为训练模式total_loss = 0.0correct = 0for i, (input_ids, attention_mask, label) in enumerate(train_loader):input_ids, attention_mask, label = input_ids.to(device), attention_mask.to(device), label.to(device)  # 将数据移到GPU上optimizer.zero_grad()  # 清空梯度logits = model(input_ids, attention_mask)  # 前向传播,计算模型输出# 对抗训练adversarial_logits = model(input_ids + torch.randn_like(input_ids) * 0.1, attention_mask)  # 对输入进行随机扰动,以增加模型的鲁棒性loss = criterion(logits, label) + criterion(adversarial_logits, label)  # 计算损失函数loss.backward()  # 反向传播,计算梯度optimizer.step()  # 更新模型参数total_loss += loss.item()  # 累计损失preds = torch.argmax(logits, dim=1)  # 获取模型预测的标签correct += (preds == label).sum().item()  # 计算预测正确的样本数return total_loss / len(train_loader), correct / len(train_loader.dataset)# 定义测试函数
def test(model, test_loader, criterion, device):model.eval()  # 设置模型为评估模式total_loss = 0.0correct = 0with torch.no_grad():  # 不进行梯度计算for i, (input_ids, attention_mask, label) in enumerate(test_loader):input_ids, attention_mask, label = input_ids.to(device), attention_mask.to(device), label.to(device)  # 将数据移到GPU上logits = model(input_ids, attention_mask)  # 前向传播,计算模型输出loss = criterion(logits, label)  # 计算损失函数total_loss += loss.item()  # 累计损失preds = torch.argmax(logits, dim=1)  # 获取模型预测的标签correct += (preds == label).sum().item()  # 计算预测正确的样本数return total_loss / len(test_loader), correct / len(test_loader.dataset)# 加载数据集和tokenizer
tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
train_dataset = SST2Dataset("train.txt", tokenizer)
test_dataset = SST2Dataset("test.txt", tokenizer)
dev_dataset = SST2Dataset("dev.txt", tokenizer)# 定义超参数
batch_size = 16
learning_rate = 2e-5
epochs = 3# 定义设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 定义模型、优化器和损失函数
roberta_model = RobertaModel.from_pretrained("roberta-base")  # 加载预训练模型
model = DCL_Roberta(roberta_model).to(device)  # 构建DCL_Roberta模型,并将其移动到GPU上
optimizer = optim.Adam(model.parameters(), lr=learning_rate)  # 定义Adam优化器
criterion = nn.CrossEntropyLoss()  # 定义交叉熵损失函数# 定义数据加载器
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)  # 定义训练集数据加载器
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)  # 定义测试集数据加载器
dev_loader = DataLoader(dev_dataset, batch_size=batch_size, shuffle=False)  # 定义验证集数据加载器# 训练模型
for epoch in range(epochs):train_loss, train_acc = train(model, train_loader, optimizer, criterion, device)  # 训练模型test_loss, test_acc = test(model, test_loader, criterion, device)  # 在测试集上评估模型dev_loss, dev_acc = test(model, dev_loader, criterion, device)  # 在验证集上评估模型print(f"Epoch {epoch+1}/{epochs} Train Loss: {train_loss:.4f} Train Acc: {train_acc:.4f} Test Loss: {test_loss:.4f} Test Acc: {test_acc:.4f} Dev Loss: {dev_loss:.4f} Dev Acc: {dev_acc:.4f}")  # 打印每个epoch的损失和准确率

三、改进思想

添加一些数据增强方法来提升模型的文本分类效果。在这里,我会添加两种数据增强方法:随机掩码和同义词替换。 

3.1  添加随机掩码

首先,我们来添加随机掩码。随机掩码是指在文本中随机选择一些词,然后将它们替换为掩码符号,以模拟部分词汇丢失的情况。这种方法可以帮助模型更好地学习上下文信息。

以下是添加随机掩码的代码实现:

import random# 定义数据集
class SST2Dataset(Dataset):def __init__(self, file_path, tokenizer):self.tokenizer = tokenizerself.sentences, self.labels = self.load_data(file_path)def load_data(self, file_path):sentences = []labels = []with open(file_path, "r", encoding="utf-8") as f:for line in f:sentence, label = line.strip().split("\t")  # 读取每行文本和标签sentences.append(sentence)labels.append(int(label))return sentences, labelsdef random_mask(self, sentence):tokens = self.tokenizer.tokenize(sentence)  # 对文本进行分词mask_indices = random.sample(range(1, len(tokens) - 1), int(len(tokens) * 0.15))  # 随机选择一些词进行掩码for i in mask_indices:tokens[i] = self.tokenizer.mask_token  # 将选中的词替换为掩码符号return self.tokenizer.convert_tokens_to_string(tokens)  # 将分词后的文本转换为字符串def __len__(self):return len(self.sentences)def __getitem__(self, idx):sentence = self.sentences[idx]label = self.labels[idx]masked_sentence = self.random_mask(sentence)  # 对文本进行随机掩码inputs = self.tokenizer.encode_plus(masked_sentence, add_special_tokens=True, return_tensors="pt")  # 使用tokenizer对文本进行编码input_ids = inputs["input_ids"].squeeze()  # 获取文本的token id,并去除多余的维度attention_mask = inputs["attention_mask"].squeeze()  # 获取文本的attention mask,并去除多余的维度return input_ids, attention_mask, label

3.2 添加同义词替换

然后,我们来添加同义词替换。同义词替换是指在文本中随机选择一些词,然后将它们替换为它们的同义词,以增加文本的多样性。这种方法可以帮助模型更好地学习词汇的语义信息。

以下是添加同义词替换的代码实现:

from nltk.corpus import wordnet# 定义数据集
class SST2Dataset(Dataset):def __init__(self, file_path, tokenizer):self.tokenizer = tokenizerself.sentences, self.labels = self.load_data(file_path)def load_data(self, file_path):sentences = []labels = []with open(file_path, "r", encoding="utf-8") as f:for line in f:sentence, label = line.strip().split("\t")  # 读取每行文本和标签sentences.append(sentence)labels.append(int(label))return sentences, labelsdef random_mask(self, sentence):tokens = self.tokenizer.tokenize(sentence)  # 对文本进行分词mask_indices = random.sample(range(1, len(tokens) - 1), int(len(tokens) * 0.15))  # 随机选择一些词进行掩码for i in mask_indices:tokens[i] = self.tokenizer.mask_token  # 将选中的词替换为掩码符号return self.tokenizer.convert_tokens_to_string(tokens)  # 将分词后的文本转换为字符串def synonym_replacement(self, sentence):tokens = self.tokenizer.tokenize(sentence)  # 对文本进行分词for i, token in enumerate(tokens):if token not in [self.tokenizer.cls_token, self.tokenizer.sep_token, self.tokenizer.pad_token]:synsets = wordnet.synsets(token)  # 获取词汇的同义词集合if synsets:synonyms = [synset.lemmas()[0].name() for synset in synsets]  # 获取同义词if synonyms:synonym = random.choice(synonyms)  # 随机选择一个同义词进行替换tokens[i] = synonym  # 将词汇替换为同义词return self.tokenizer.convert_tokens_to_string(tokens)  # 将分词后的文本转换为字符串def __len__(self):return len(self.sentences)def __getitem__(self, idx):sentence = self.sentences[idx]label = self.labels[idx]masked_sentence = self.random_mask(sentence)  # 对文本进行随机掩码replaced_sentence = self.synonym_replacement(masked_sentence)  # 对文本进行同义词替换inputs = self.tokenizer.encode_plus(replaced_sentence, add_special_tokens=True, return_tensors="pt")  # 使用tokenizer对文本进行编码input_ids = inputs["input_ids"].squeeze()  # 获取文本的token id,并去除多余的维度attention_mask = inputs["attention_mask"].squeeze()  # 获取文本的attention mask,并去除多余的维度return input_ids, attention_mask, label

可以选择其中一种或两种数据增强方法来使用,也可以根据需要自行添加其他数据增强方法。 

四、DCL和SimCLR对比

SimCLR和DCL都是常用的对比学习方法,可以用于文本分类任务中。它们的主要区别在于对负样本的构造方式不同。SimCLR使用随机数据增强的方式构造负样本,而DCL使用对偶对比学习的方式构造负样本。具体来说,DCL使用同一文本的不同部分作为正样本和负样本,以便模型更好地学习文本的局部特征。

实际上,SimCLR和DCL在文本分类任务中的表现都比较好。在一些研究中,SimCLR在一些数据集上的表现略优于DCL,而在另一些数据集上,DCL则表现更好。这可能与数据集的特征、模型的架构等因素有关。因此,我们无法确定哪种方法在所有情况下都表现更好。

在实际应用中,我们可以尝试使用SimCLR和DCL两种方法,然后根据实验结果选择更适合我们的任务的方法。另外,我们也可以尝试使用其他对比学习方法,以便找到最适合我们任务的方法。

 


http://www.ppmy.cn/news/310769.html

相关文章

Linux C简易聊天室

对于初学者而已,我们学习的网络编程(如TCP,UDP编程),我们通常都是在局域网内进行通信测试,有时候我们或者会想,我们现在写的内网网络数据和外网的网络数据有什么不同,我们内网的数据是如何走出外…

软考A计划-系统架构师-官方考试指定教程-(15/15)

点击跳转专栏>Unity3D特效百例点击跳转专栏>案例项目实战源码点击跳转专栏>游戏脚本-辅助自动化点击跳转专栏>Android控件全解手册点击跳转专栏>Scratch编程案例 👉关于作者 专注于Android/Unity和各种游戏开发技巧,以及各种资源分享&am…

leetcode第314场周赛补题

第一题&#xff1a;6200. 处理用时最长的那个任务的员工 原题链接 思路&#xff1a;简单模拟&#xff0c;遍历取最大值即可 class Solution { public:int hardestWorker(int n, vector<vector<int>>& logs) {int res logs[0][0];int sum logs[0][1];for(in…

Android 搜索内容保存历史记录

Android 搜索内容保存历史记录 一、界面布局 主界面布局 activity_main.xml <?xml version"1.0" encoding"utf-8"?> <LinearLayout xmlns:android"http://schemas.android.com/apk/res/android"xmlns:app"http://schemas.and…

Android Studio入门之文本内容、大小、颜色的讲解及实战(附源码 超详细必看)

运行有问题或需要源码请点赞关注收藏后评论区留言或私信博主 一、设置文本的内容 1:在XML文件中通过属性android:text设置文本 <TextViewandroid:layout_width"wrap_content"android:layout_height"wrap_content"android:text"Hello World!"…

Compose (9/N) - 主题 Theme

一、Material Design 直接把任何Composable函数用 MaterialTheme{ } 包裹起来&#xff0c;就可以使用相关属性了。也可以单独将某个属性拿出来使用。 1.1 颜色 Color primary 主色&#xff0c;屏幕和元素都用这个颜色。 primaryVariant 用于区分主色&#xff0c;比如app bar和…

android 登录注册动画,Android开发(14)——动画实战:炫酷登录

本节内容 1.第三方库实现虚化 2.添加输入框和按钮 3.按钮状态 4.键盘隐藏 5.监听焦点改变的事件 6.手臂旋转动画 7.手掌和手臂动画 Demo简介 1.做一个炫酷登录的界面。 image.png 当我们输入密码的时候&#xff0c;猫头鹰会捂住眼睛。点击其他地方后&#xff0c;它的手臂又会张…

Android开发帮助技巧(适用于入门)二

Android配置配置清单&#xff0c;Gradle构建文件&#xff0c;外部化资源部分。 &#xff08;一)配置清单AndroidManifest.xml use-permission和uses-feature和application属于同等级节点。 android:allowBackup 是否允许应用参与备份和恢复基础架构。如果将此属性设为 fals…