NLP transformers - Translation

python">from transformers import AutoTokenizer#加载编码器
tokenizer = AutoTokenizer.from_pretrained('Helsinki-NLP/opus-mt-en-ro',use_fast=True)print(tokenizer)#编码试算
tokenizer.batch_encode_plus([['Hello, this one sentence!', 'This is another sentence.']])

python">PreTrainedTokenizer(name_or_path='Helsinki-NLP/opus-mt-en-ro', vocab_size=59543, model_max_len=512, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'eos_token': '</s>', 'unk_token': '<unk>', 'pad_token': '<pad>'})
{'input_ids': [[125, 778, 3, 63, 141, 9191, 23, 187, 32, 716, 9191, 2, 0]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]}
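As a quick sanity check, the ids can be decoded back to text; a minimal sketch, assuming the `tokenizer` loaded above. The trailing `0` in the output appears to be the `</s>` end-of-sequence token that the Marian tokenizer appends.

```python
# Decode the trial encoding back to text to confirm the round trip.
ids = [125, 778, 3, 63, 141, 9191, 23, 187, 32, 716, 9191, 2, 0]
print(tokenizer.decode(ids))                            # with special tokens
print(tokenizer.decode(ids, skip_special_tokens=True))  # plain text only
```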

python">from datasets import load_dataset, load_from_disk#加载数据
dataset = load_dataset(path='wmt16', name='ro-en')
# dataset = load_from_disk('datas/wmt16/ro-en')#采样,数据量太大了跑不动
dataset['train'] = dataset['train'].shuffle(1).select(range(20000))
dataset['validation'] = dataset['validation'].shuffle(1).select(range(200))
dataset['test'] = dataset['test'].shuffle(1).select(range(200))#数据预处理
def preprocess_function(data):#取出数据中的en和roen = [ex['en'] for ex in data['translation']]ro = [ex['ro'] for ex in data['translation']]#源语言直接编码就行了data = tokenizer.batch_encode_plus(en, max_length=128, truncation=True)#目标语言在特殊模块中编码with tokenizer.as_target_tokenizer():data['labels'] = tokenizer.batch_encode_plus(ro, max_length=128, truncation=True)['input_ids']return datadataset = dataset.map(function=preprocess_function,batched=True,batch_size=1000,num_proc=4,remove_columns=['translation'])print(dataset['train'][0])dataset

python">{'input_ids': [460, 354, 3794, 12, 10677, 20, 5046, 14, 4, 2546, 37, 8, 397, 5551, 30, 10113, 37, 3501, 19814, 18, 8465, 20, 4, 44690, 782, 2, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'labels': [902, 576, 2946, 76, 10815, 17, 5098, 14997, 5, 559, 1140, 43, 2434, 6624, 27, 50, 337, 19216, 46, 22174, 17, 2317, 121, 16825, 2, 0]}
DatasetDict({train: Dataset({features: ['input_ids', 'attention_mask', 'labels'],num_rows: 20000})validation: Dataset({features: ['input_ids', 'attention_mask', 'labels'],num_rows: 200})test: Dataset({features: ['input_ids', 'attention_mask', 'labels'],num_rows: 200})
})
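To verify that preprocessing kept the pairs aligned, both sides of an example can be decoded back to text; a minimal sketch, assuming the `tokenizer` and `dataset` built above.

```python
# Decode both sides of the first preprocessed training example.
sample = dataset['train'][0]
print('en =', tokenizer.decode(sample['input_ids'], skip_special_tokens=True))
print('ro =', tokenizer.decode(sample['labels'], skip_special_tokens=True))
```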

python">#这个函数和下面这个工具类等价,但我也是模仿实现的,不确定有没有出入
#from transformers import DataCollatorForSeq2Seq
#DataCollatorForSeq2Seq(tokenizer, model=model)import torch#数据批处理函数
def collate_fn(data):#求最长的labelmax_length = max([len(i['labels']) for i in data])#把所有的label都补pad到最长for i in data:pads = [-100] * (max_length - len(i['labels']))i['labels'] = i['labels'] + pads#把多个数据整合成一个tensordata = tokenizer.pad(encoded_inputs=data,padding=True,max_length=None,pad_to_multiple_of=None,return_tensors='pt',)#定义decoder_input_idsdata['decoder_input_ids'] = torch.full_like(data['labels'],tokenizer.get_vocab()['<pad>'],dtype=torch.long)data['decoder_input_ids'][:, 1:] = data['labels'][:, :-1]data['decoder_input_ids'][data['decoder_input_ids'] ==-100] = tokenizer.get_vocab()['<pad>']return datadata = [{'input_ids': [21603, 10, 37, 3719, 13],'attention_mask': [1, 1, 1, 1, 1],'labels': [10455, 120, 80]
}, {'input_ids': [21603, 10, 7086, 8408, 563],'attention_mask': [1, 1, 1, 1, 1],'labels': [301, 53, 4074, 1669]
}]collate_fn(data)['decoder_input_ids']

python">tensor([[59542, 10455,   120,    80],[59542,   301,    53,  4074]])

python">import torch#数据加载器
loader = torch.utils.data.DataLoader(dataset=dataset['train'],batch_size=8,collate_fn=collate_fn,shuffle=True,drop_last=True,
)for i, data in enumerate(loader):breakfor k, v in data.items():print(k, v.shape, v[:2])len(loader)

python">from transformers import AutoModelForSeq2SeqLM, MarianModel#加载模型
#model = AutoModelForSeq2SeqLM.from_pretrained('Helsinki-NLP/opus-mt-en-ro')#定义下游任务模型
class Model(torch.nn.Module):def __init__(self):super().__init__()self.pretrained = MarianModel.from_pretrained('Helsinki-NLP/opus-mt-en-ro')self.register_buffer('final_logits_bias',torch.zeros(1, tokenizer.vocab_size))self.fc = torch.nn.Linear(512, tokenizer.vocab_size, bias=False)#加载预训练模型的参数parameters = AutoModelForSeq2SeqLM.from_pretrained('Helsinki-NLP/opus-mt-en-ro')self.fc.load_state_dict(parameters.lm_head.state_dict())self.criterion = torch.nn.CrossEntropyLoss()def forward(self, input_ids, attention_mask, labels, decoder_input_ids):logits = self.pretrained(input_ids=input_ids,attention_mask=attention_mask,decoder_input_ids=decoder_input_ids)logits = logits.last_hidden_statelogits = self.fc(logits) + self.final_logits_biasloss = self.criterion(logits.flatten(end_dim=1), labels.flatten())return {'loss': loss, 'logits': logits}model = Model()#统计参数量
print(sum(i.numel() for i in model.parameters()) / 10000)#out = model(**data)
#out['loss'], out['logits'].shape
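```

The two commented-out lines sketch a sanity check; spelled out, a single forward pass with the batch `data` taken from the loader above confirms the output shapes and that the loss is finite. A minimal sketch under those assumptions:

```python
# One forward pass to verify the wiring of the downstream model.
with torch.no_grad():
    out = model(**data)

print(out['loss'])          # scalar cross-entropy loss
print(out['logits'].shape)  # [batch_size, seq_len, vocab_size] -> [8, seq_len, 59543]
```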

python">from datasets import load_metric#加载评价函数
metric = load_metric(path='sacrebleu')#试算
metric.compute(predictions=['hello there', 'general kenobi'],references=[['hello there'], ['general kenobi']])
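One portability note: on newer library versions `datasets.load_metric` has been removed in favour of the `evaluate` package, which loads the same metric. Either way, the returned dict carries the corpus BLEU in its `'score'` key, which is what the test and train logs below ultimately report. A hedged sketch, assuming `evaluate` is installed:

```python
# Equivalent metric loading on newer library versions.
import evaluate

metric = evaluate.load('sacrebleu')

out = metric.compute(predictions=['hello there', 'general kenobi'],
                     references=[['hello there'], ['general kenobi']])
print(out['score'])  # 100.0 for exact matches
```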

python">

Test

python">#测试
def test():model.eval()#数据加载器loader_test = torch.utils.data.DataLoader(dataset=dataset['test'],batch_size=8,collate_fn=collate_fn,shuffle=True,drop_last=True,)predictions = []references = []for i, data in enumerate(loader_test):#计算with torch.no_grad():out = model(**data)pred = tokenizer.batch_decode(out['logits'].argmax(dim=2))label = tokenizer.batch_decode(data['decoder_input_ids'])predictions.extend(pred)references.extend(label)if i % 2 == 0:print(i)input_ids = tokenizer.decode(data['input_ids'][0])print('input_ids=', input_ids)print('pred=', pred[0])print('label=', label[0])if i == 10:breakreferences = [[j] for j in references]metric_out = metric.compute(predictions=predictions, references=references)print(metric_out)test()

python">

python">from transformers import AdamW
from transformers.optimization import get_scheduler#训练
def train():optimizer = AdamW(model.parameters(), lr=2e-5)scheduler = get_scheduler(name='linear',num_warmup_steps=0,num_training_steps=len(loader),optimizer=optimizer)model.train()for i, data in enumerate(loader):out = model(**data)loss = out['loss']loss.backward()torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)optimizer.step()scheduler.step()optimizer.zero_grad()model.zero_grad()if i % 50 == 0:out = out['logits'].argmax(dim=2)correct = (data['decoder_input_ids'] == out).sum().item()total = data['decoder_input_ids'].shape[1] * 8accuracy = correct / totaldel correctdel totalpredictions = []references = []for j in range(8):pred = tokenizer.decode(out[j])label = tokenizer.decode(data['decoder_input_ids'][j])predictions.append(pred)references.append([label])metric_out = metric.compute(predictions=predictions,references=references)lr = optimizer.state_dict()['param_groups'][0]['lr']print(i, loss.item(), accuracy, metric_out, lr)torch.save(model, 'models/7.翻译.model')train()
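A second portability note: `transformers.AdamW` has long been deprecated and is removed in recent releases; `torch.optim.AdamW` is the drop-in replacement. A hedged sketch of the one-line swap, everything else in `train()` unchanged:

```python
import torch

# On newer transformers versions, use the PyTorch optimizer instead.
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
```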

python">

python">model = torch.load('models/7.翻译.model')
test()
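With the reloaded model, a new sentence can be translated with a simple greedy loop: the hand-rolled `Model` has no `generate()` method (unlike `AutoModelForSeq2SeqLM`), so decoding starts from the `<pad>` start token and repeatedly appends the argmax token until `</s>`. A minimal sketch, assuming the `model` and `tokenizer` above; `translate` is a hypothetical helper name, not part of the original code:

```python
def translate(text, max_length=128):
    # Hypothetical greedy decoder for the custom Model wrapper.
    model.eval()
    enc = tokenizer([text], return_tensors='pt')

    pad_id = tokenizer.get_vocab()['<pad>']
    eos_id = tokenizer.eos_token_id
    decoder_input_ids = torch.full((1, 1), pad_id, dtype=torch.long)

    with torch.no_grad():
        for _ in range(max_length):
            out = model(input_ids=enc['input_ids'],
                        attention_mask=enc['attention_mask'],
                        # labels is unused for decoding; passed only because
                        # the custom forward() always computes a loss
                        labels=decoder_input_ids,
                        decoder_input_ids=decoder_input_ids)
            # Greedy step: take the most likely next token
            next_id = out['logits'][:, -1].argmax(dim=-1, keepdim=True)
            decoder_input_ids = torch.cat([decoder_input_ids, next_id], dim=1)
            if next_id.item() == eos_id:
                break

    return tokenizer.decode(decoder_input_ids[0], skip_special_tokens=True)


print(translate('Hello, how are you?'))
```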

python">

python">

python">

python">

python">


http://www.ppmy.cn/server/23898.html

相关文章

js之探索浏览器对象模型

浏览器对象模型&#xff08;Browser Object Model, BOM&#xff09;是Web开发中的重要组成部分&#xff0c;它提供了一种与浏览器交互的方式&#xff0c;允许开发人员控制浏览器窗口、处理用户输入、管理浏览历史等。在本文中&#xff0c;我们将深入探讨BOM的核心概念、结构以及…

JVM(Jvm如何管理空间?对象如何存储、管理?)

Jvm如何管理空间&#xff08;Java运行时数据区域与分配空间的方式&#xff09; ⭐运行时数据区域 程序计数器 程序计数器&#xff08;PC&#xff09;&#xff0c;是一块较小的内存空。它可以看作是当前线程所执行的字节码的行号指示器。Java虚拟机的多线程是通过时间片轮转调…

VUE的生命周期图和各函数

函数 beforeCreate(){ }, created(){ }, beforeMount(){ }, mounted(){ }, beforeUpdate(){ }, updated(){ }, beforeDestroy(){ }, destroyed(){ } 创建时生命周期图 运行时生命周期图

【MHA】MySQL高可用MHA源码1-主库故障监控

1 阅读之前的准备工作 1 一个IDE工具 &#xff0c;博主自己尝试了vscode安装perl的插件&#xff0c;但是函数 、变量 、模块等都不能跳转&#xff0c;阅读起来不是很方便。后来尝试使用了pycharm安装perl插件&#xff0c;阅读支持跳转&#xff0c;自己也能写一些简单的测试样例…

飞行汽车飞行控制系统功能详解

飞行汽车是一种创新的交通工具&#xff0c;结合了汽车和飞机的特点。它可以在陆地上行驶&#xff0c;同时也具备在空中飞行的能力。飞行汽车的概念已经存在多年&#xff0c;并且近年来随着技术的进步和研发的深入&#xff0c;这种交通工具正在逐渐从概念走向现实。 飞行汽车的…

LeetCode-非递增子序列

每日一题 今天刷的依旧是一道中等题&#xff0c;不过感觉今天这道题是中等难度里面比较难的题了&#xff0c;思考了很长时间。过程感觉比较难以理解。 题目要求 给你一个整数数组 nums &#xff0c;找出并返回所有该数组中不同的递增子序列&#xff0c;递增子序列中 至少有两…

第27篇 Spring简介

Spring框架是Java企业级应用的主流框架&#xff0c;其主要基于IoC&#xff08;Inversion of Control&#xff0c;控制反转&#xff09;和DI&#xff08;Dependency Injection&#xff0c;依赖注入&#xff09;设计原则。Spring的核心语法主要包括Bean的定义、装配、自动扫描、A…

利用GaussDB的可观测性能力构建故障模型

D-SMART高斯专版已经开发了几个月了&#xff0c;目前主要技术问题都已经解决&#xff0c;也能够初步看到大概的面貌了。有朋友问我&#xff0c;GaussDB不已经有了TPOPS了&#xff0c;为什么你们还要开发D-SMART高斯专版呢&#xff1f; 实际上TPOPS和D-SMART虽然都可以用于Gaus…