transformers - Predicting the Middle Word


Code


from transformers import AutoTokenizer

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained('distilroberta-base', use_fast=True)
print(tokenizer)

# Trial encoding
tokenizer.batch_encode_plus([
    'hide new secretions from the parental units',
    'contains no wit , only labored gags',
])

PreTrainedTokenizerFast(name_or_path='distilroberta-base', vocab_size=50265, model_max_len=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>', 'sep_token': '</s>', 'pad_token': '<pad>', 'cls_token': '<s>', 'mask_token': AddedToken("<mask>", rstrip=False, lstrip=True, single_word=False, normalized=False)})
{'input_ids': [[0, 37265, 92, 3556, 2485, 31, 5, 20536, 2833, 2], [0, 10800, 5069, 117, 22094, 2156, 129, 6348, 3995, 821, 8299, 2]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]}
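As a quick sanity check (a minimal sketch, not part of the original code), the encoded ids can be decoded back to confirm the special tokens the tokenizer adds around each sentence:

# Sketch: decode the first encoded sentence; <s> and </s> should wrap the text.
ids = tokenizer.batch_encode_plus(['hide new secretions from the parental units'])['input_ids'][0]
print(tokenizer.decode(ids))  # roughly: <s>hide new secretions from the parental units</s>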

Loading the Data

from datasets import load_dataset, load_from_disk

# Load the dataset
dataset = load_dataset(path='glue', name='sst2')
# dataset = load_from_disk('datas/glue/sst2')

# Tokenize, dropping the columns that are no longer needed
def f(data):
    return tokenizer.batch_encode_plus(data['sentence'])

dataset = dataset.map(f,
                      batched=True,
                      batch_size=1000,
                      num_proc=4,
                      remove_columns=['sentence', 'idx', 'label'])

# Filter out sentences that are too short
def f(data):
    return [len(i) >= 9 for i in data['input_ids']]

dataset = dataset.filter(f, batched=True, batch_size=1000, num_proc=4)

# Truncate the sentences and arrange them into the format the model expects
def f(data):
    b = len(data['input_ids'])
    data['labels'] = data['attention_mask'].copy()
    for i in range(b):
        # Truncate to length 9
        data['input_ids'][i] = data['input_ids'][i][:9]
        data['attention_mask'][i] = [1] * 9
        data['labels'][i] = [-100] * 9

        # The last position of input_ids is the </s> token (id 2)
        data['input_ids'][i][-1] = 2

        # The token at index 4 of every sentence becomes the mask
        # tokenizer.get_vocab()['<mask>'] -> 50264
        data['labels'][i][4] = data['input_ids'][i][4]
        data['input_ids'][i][4] = 50264
    return data

dataset = dataset.map(f, batched=True, batch_size=1000, num_proc=4)

dataset, dataset['train'][0]
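To verify what the preprocessing produces, one processed example can be decoded back. A minimal sketch; the exact ids depend on the SST-2 download, so treat the printed values as illustrative:

# Sketch: inspect one processed training example.
sample = dataset['train'][0]
print(sample['input_ids'])                    # 9 ids, 50264 (<mask>) at index 4, 2 (</s>) at the end
print(tokenizer.decode(sample['input_ids']))  # a sentence fragment with <mask> in the middle
print(sample['labels'])                       # -100 everywhere except the original id at index 4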

import torch
from transformers.data.data_collator import default_data_collator

# DataCollatorForLanguageModeling is a collate_fn that applies random masking.
# If you use it, there is no need to set the mask during preprocessing;
# just set labels = input_ids.copy()
# from transformers import DataCollatorForLanguageModeling
# data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.1)

# Data loader
loader = torch.utils.data.DataLoader(dataset=dataset['train'],
                                     batch_size=8,
                                     collate_fn=default_data_collator,
                                     shuffle=True,
                                     drop_last=True)

for i, data in enumerate(loader):
    break

len(loader), data

(5534,
 {'input_ids': tensor([[    0, 12196,   128,    29, 50264, 10132,    59,  9326,     2],
         [    0,  1250,     5,  3768, 50264, 34948, 16658,     8,     2],
         [    0,   627,   936,    16, 50264,   240, 12445,  2129,     2],
         [    0,  3654,   350, 13185, 50264,    45,   350,  8794,     2],
         [    0,   560,    28,    56, 50264,  3541, 34261,    19,     2],
         [    0,   560,   224,    14, 50264,   473,   295,    75,     2],
         [    0,     6, 14784,  1054, 50264,    10,   686,   865,     2],
         [    0,  9006,  1495,  2156, 50264, 23317,  4780,     8,     2]]),
  'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1, 1, 1, 1, 1]]),
  'labels': tensor([[-100, -100, -100, -100,  144, -100, -100, -100, -100],
         [-100, -100, -100, -100,   32, -100, -100, -100, -100],
         [-100, -100, -100, -100,    5, -100, -100, -100, -100],
         [-100, -100, -100, -100, 2156, -100, -100, -100, -100],
         [-100, -100, -100, -100,   31, -100, -100, -100, -100],
         [-100, -100, -100, -100,   24, -100, -100, -100, -100],
         [-100, -100, -100, -100,   34, -100, -100, -100, -100],
         [-100, -100, -100, -100,   10, -100, -100, -100, -100]])})
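The commented-out DataCollatorForLanguageModeling above masks tokens randomly inside the collate_fn rather than at preprocessing time. A minimal sketch of that alternative, assuming the dataset were preprocessed without the fixed mask at index 4 (the collator builds the labels itself):

# Sketch: random masking in the collator instead of a fixed mask at index 4.
from transformers import DataCollatorForLanguageModeling

data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.1)
loader_random = torch.utils.data.DataLoader(dataset=dataset['train'],
                                            batch_size=8,
                                            collate_fn=data_collator,
                                            shuffle=True,
                                            drop_last=True)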

from transformers import AutoModelForCausalLM, RobertaModel

# Load the model
# model = AutoModelForCausalLM.from_pretrained('distilroberta-base')

# Define the downstream task model
class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.pretrained = RobertaModel.from_pretrained('distilroberta-base')

        decoder = torch.nn.Linear(768, tokenizer.vocab_size)
        decoder.bias = torch.nn.Parameter(torch.zeros(tokenizer.vocab_size))

        self.fc = torch.nn.Sequential(
            torch.nn.Linear(768, 768),
            torch.nn.GELU(),
            torch.nn.LayerNorm(768, eps=1e-5),
            decoder,
        )

        # Load the parameters of the pretrained lm_head
        parameters = AutoModelForCausalLM.from_pretrained('distilroberta-base')
        self.fc[0].load_state_dict(parameters.lm_head.dense.state_dict())
        self.fc[2].load_state_dict(parameters.lm_head.layer_norm.state_dict())
        self.fc[3].load_state_dict(parameters.lm_head.decoder.state_dict())

        self.criterion = torch.nn.CrossEntropyLoss()

    def forward(self, input_ids, attention_mask, labels=None):
        logits = self.pretrained(input_ids=input_ids, attention_mask=attention_mask)
        logits = logits.last_hidden_state
        logits = self.fc(logits)

        loss = None
        if labels is not None:
            shifted_logits = logits[:, :-1].reshape(-1, tokenizer.vocab_size)
            shifted_labels = labels[:, 1:].reshape(-1)
            loss = self.criterion(shifted_logits, shifted_labels)

        return {'loss': loss, 'logits': logits}


model = Model()

# Count the parameters (in units of 10k)
print(sum(i.numel() for i in model.parameters()) / 10000)

out = model(**data)
out['loss'], out['logits'].shape
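Note how the loss is computed: logits[:, :-1] are scored against labels[:, 1:], so the label at index 4 (the masked token) is paired with the hidden state at index 3, in the style of a causal LM. A minimal sketch that reproduces the same alignment on the batch `data` from above, just to make the shapes explicit:

# Sketch: reproduce the shift used inside forward() for one batch.
out = model(**data)
shifted_logits = out['logits'][:, :-1].reshape(-1, tokenizer.vocab_size)  # positions 0..7 of each sentence
shifted_labels = data['labels'][:, 1:].reshape(-1)                        # positions 1..8 of each sentence
print(shifted_logits.shape, shifted_labels.shape)                         # torch.Size([64, 50265]) torch.Size([64])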



Testing

# Test
def test():
    model.eval()

    loader_test = torch.utils.data.DataLoader(dataset=dataset['test'],
                                              batch_size=8,
                                              collate_fn=default_data_collator,
                                              shuffle=True,
                                              drop_last=True)

    correct = 0
    total = 0
    for i, data in enumerate(loader_test):
        # Keep a copy of the labels; they are needed for the accuracy calculation below
        label = data['labels'][:, 4].clone()

        # Remove the labels from the data so the model cannot cheat
        data['labels'] = None

        # Forward pass
        with torch.no_grad():
            out = model(**data)

        # [8, 9, 50265] -> [8]
        out = out['logits'].argmax(dim=2)[:, 4]

        correct += (label == out).sum().item()
        total += 8

        if i % 10 == 0:
            print(i)
            print(label)
            print(out)

        if i == 50:
            break

    print(correct / total)

    for i in range(8):
        print(tokenizer.decode(data['input_ids'][i]))
        print(tokenizer.decode(label[i]), tokenizer.decode(out[i]))


test()

0
tensor([   47, 14838,  5392,    28,    80,  4839,  3668,    29])
tensor([   47, 14633,   749,    28,    80,  4839,  3668,  2156])
10
tensor([ 101,  668,   16,   14,  352,  650, 3961,   16])
tensor([ 101,  773, 7897,   59, 2156, 7397, 3961,   16])
20
tensor([40485,    13,    29, 19303,    33,    16,   295,     9])
tensor([40485,    13,  4839, 16393,    33,  3391,   256,     9])
30
tensor([   53, 33469,  3315,  3723,     7, 24473, 40776,    41])
tensor([11248, 15923,  3315,  3723,     7, 24473, 40776,    41])
40
tensor([ 2435,     5,  2046,  2084, 25210,     9, 42661,     7])
tensor([ 2343,    42,  4265,  8003, 33709,  7021,  9021,     6])
50
tensor([  297, 22258,   998,    64,    10,  1499,    65,  2156])
tensor([  457, 22258,  6545,    64,    10, 10416,    65, 33647])
0.32598039215686275
<s>a strong first<mask>, slightly less</s>
quarter  half
<s>( villene<mask> ) seems to</s>
uve uve
<s>going to the<mask> may be just</s>
website  gym

from transformers import AdamW
from transformers.optimization import get_scheduler

# Training
def train():
    optimizer = AdamW(model.parameters(), lr=2e-5)
    scheduler = get_scheduler(name='linear',
                              num_warmup_steps=0,
                              num_training_steps=len(loader),
                              optimizer=optimizer)

    model.train()
    for i, data in enumerate(loader):
        out = model(**data)
        loss = out['loss']

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        optimizer.step()
        scheduler.step()

        optimizer.zero_grad()
        model.zero_grad()

        if i % 50 == 0:
            label = data['labels'][:, 4]
            out = out['logits'].argmax(dim=2)[:, 4]

            correct = (label == out).sum().item()
            accuracy = correct / 8

            lr = optimizer.state_dict()['param_groups'][0]['lr']
            print(i, loss.item(), accuracy, lr)

    torch.save(model, 'models/2.预测中间词.model')


train()

/root/anaconda3/envs/cpu/lib/python3.6/site-packages/transformers/optimization.py:309: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning
  FutureWarning,
0 18.949838638305664 0.0 1.9996385977593064e-05
50 4.755198001861572 0.625 1.9815684857246115e-05
100 5.0272216796875 0.25 1.963498373689917e-05
150 4.625316143035889 0.125 1.9454282616552225e-05
200 3.663780927658081 0.5 1.927358149620528e-05
250 2.5342917442321777 0.375 1.909288037585833e-05
300 4.986537933349609 0.375 1.8912179255511386e-05
350 3.403028964996338 0.625 1.873147813516444e-05
400 4.041268348693848 0.125 1.8550777014817495e-05
450 3.2715964317321777 0.5 1.8370075894470547e-05
500 2.6591811180114746 0.5 1.81893747741236e-05
550 4.937175750732422 0.25 1.8008673653776656e-05
600 4.845945835113525 0.25 1.7827972533429708e-05
650 1.8658218383789062 0.625 1.7647271413082763e-05
700 3.9473319053649902 0.25 1.7466570292735818e-05
750 2.065851926803589 0.625 1.728586917238887e-05
800 2.957096576690674 0.5 1.7105168052041924e-05
850 4.987250804901123 0.25 1.692446693169498e-05
900 3.5697021484375 0.5 1.674376581134803e-05
950 2.898092746734619 0.5 1.6563064691001085e-05
1000 4.39031457901001 0.375 1.638236357065414e-05
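The FutureWarning printed before the training log suggests replacing the deprecated transformers AdamW with PyTorch's own implementation. A minimal sketch of that drop-in swap (not part of the original run):

# Sketch: use torch.optim.AdamW instead of the deprecated transformers.AdamW.
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)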

Prediction

# Load the saved model and run the test again
model = torch.load('models/2.预测中间词.model')
test()
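Beyond re-running test() on the SST-2 test split, the reloaded model can also fill in a hand-written sentence. A minimal sketch; the example sentence and variable names are made up for illustration and are not part of the original code:

# Sketch: predict the masked word in an arbitrary sentence (illustrative only).
model.eval()
text = 'the movie is <mask> good'                     # hypothetical input
enc = tokenizer(text, return_tensors='pt')
with torch.no_grad():
    out = model(input_ids=enc['input_ids'], attention_mask=enc['attention_mask'])
mask_pos = (enc['input_ids'][0] == tokenizer.mask_token_id).nonzero().item()
pred_id = out['logits'][0, mask_pos].argmax().item()
print(tokenizer.decode(pred_id))                      # the model's guess for <mask>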

2022-12-08

