代码
from transformers import AutoTokenizer#加载编码器
tokenizer = AutoTokenizer.from_pretrained('distilroberta-base', use_fast=True)
print(tokenizer)

# Trial encoding: run two sample sentences through the tokenizer to see the
# input_ids / attention_mask it produces.
tokenizer.batch_encode_plus([
    'hide new secretions from the parental units',
    'contains no wit , only labored gags',
])
PreTrainedTokenizerFast(name_or_path='distilroberta-base', vocab_size=50265, model_max_len=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>', 'sep_token': '</s>', 'pad_token': '<pad>', 'cls_token': '<s>', 'mask_token': AddedToken("<mask>", rstrip=False, lstrip=True, single_word=False, normalized=False)})
{'input_ids': [[0, 37265, 92, 3556, 2485, 31, 5, 20536, 2833, 2], [0, 10800, 5069, 117, 22094, 2156, 129, 6348, 3995, 821, 8299, 2]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]}
加载数据
from datasets import load_dataset, load_from_disk#加载数据
# Load the SST-2 configuration of the GLUE dataset.
dataset = load_dataset(path='glue', name='sst2')
# Alternative: load it from a local directory instead.
# dataset = load_from_disk('datas/glue/sst2')


# Tokenize, and at the same time drop the columns that are no longer needed.
def f(data):
    """Tokenize the 'sentence' column of one batch with the global tokenizer."""
    sentences = data['sentence']
    return tokenizer.batch_encode_plus(sentences)


dataset = dataset.map(
    f,
    batched=True,
    batch_size=1000,
    num_proc=4,
    remove_columns=['sentence', 'idx', 'label'],
)

# Filter out sentences that are too short.
def f(data):
    """Return one boolean per example: True when it has at least 9 token ids."""
    return [len(token_ids) >= 9 for token_ids in data['input_ids']]


dataset = dataset.filter(f, batched=True, batch_size=1000, num_proc=4)

# Truncate the sentences and arrange them in the format the model needs.
def f(data):
    """Cut every example to 9 tokens, mask position 4 and build the labels.

    For each example in the batch:
      * input_ids is truncated to 9 ids and its last id is forced to 2
        (the id every sample encoding above ends with),
      * attention_mask becomes nine 1s,
      * labels is -100 everywhere except index 4, which holds the original
        token id at that position,
      * input_ids[4] is replaced by 50264, i.e.
        tokenizer.get_vocab()['<mask>'].

    Returns the same batch mapping with 'labels' added.
    """
    masked_ids = []
    targets = []
    for token_ids in data['input_ids']:
        # Trim to length 9; the last id is always 2.
        token_ids = token_ids[:9]
        token_ids[-1] = 2

        # Only the masked position carries a real label.
        target = [-100] * 9
        target[4] = token_ids[4]

        # Index 4 of every sentence becomes the <mask> token.
        token_ids[4] = 50264

        masked_ids.append(token_ids)
        targets.append(target)

    data['input_ids'] = masked_ids
    data['attention_mask'] = [[1] * 9 for _ in masked_ids]
    data['labels'] = targets
    return data


dataset = dataset.map(f, batched=True, batch_size=1000, num_proc=4)

dataset, dataset['train'][0]
import torch
from transformers.data.data_collator import default_data_collator#能够实现随机mask的collate_fn
# To get random masking instead, use DataCollatorForLanguageModeling as the
# collate_fn. In that case the preprocessing step must not insert the mask
# itself; just set labels = input_ids.copy().
# from transformers import DataCollatorForLanguageModeling
# data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.1)

# Data loader over the training split.
loader = torch.utils.data.DataLoader(
    dataset=dataset['train'],
    batch_size=8,
    collate_fn=default_data_collator,
    shuffle=True,
    drop_last=True,
)

# Grab a single batch so it can be inspected and reused below.
for i, data in enumerate(loader):
    break

len(loader), data
(5534,{'input_ids': tensor([[ 0, 12196, 128, 29, 50264, 10132, 59, 9326, 2],[ 0, 1250, 5, 3768, 50264, 34948, 16658, 8, 2],[ 0, 627, 936, 16, 50264, 240, 12445, 2129, 2],[ 0, 3654, 350, 13185, 50264, 45, 350, 8794, 2],[ 0, 560, 28, 56, 50264, 3541, 34261, 19, 2],[ 0, 560, 224, 14, 50264, 473, 295, 75, 2],[ 0, 6, 14784, 1054, 50264, 10, 686, 865, 2],[ 0, 9006, 1495, 2156, 50264, 23317, 4780, 8, 2]]),'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1],[1, 1, 1, 1, 1, 1, 1, 1, 1],[1, 1, 1, 1, 1, 1, 1, 1, 1],[1, 1, 1, 1, 1, 1, 1, 1, 1],[1, 1, 1, 1, 1, 1, 1, 1, 1],[1, 1, 1, 1, 1, 1, 1, 1, 1],[1, 1, 1, 1, 1, 1, 1, 1, 1],[1, 1, 1, 1, 1, 1, 1, 1, 1]]),'labels': tensor([[-100, -100, -100, -100, 144, -100, -100, -100, -100],[-100, -100, -100, -100, 32, -100, -100, -100, -100],[-100, -100, -100, -100, 5, -100, -100, -100, -100],[-100, -100, -100, -100, 2156, -100, -100, -100, -100],[-100, -100, -100, -100, 31, -100, -100, -100, -100],[-100, -100, -100, -100, 24, -100, -100, -100, -100],[-100, -100, -100, -100, 34, -100, -100, -100, -100],[-100, -100, -100, -100, 10, -100, -100, -100, -100]])})
from transformers import AutoModelForCausalLM, RobertaModel#加载模型
# Alternative: use the pretrained LM-head model as-is.
# model = AutoModelForCausalLM.from_pretrained('distilroberta-base')


# Define the downstream-task model.
class Model(torch.nn.Module):
    """DistilRoBERTa encoder followed by a vocabulary-sized prediction head.

    The head (Linear -> GELU -> LayerNorm -> Linear) is initialised from the
    ``lm_head`` of the pretrained 'distilroberta-base' LM checkpoint.
    """

    def __init__(self):
        super().__init__()
        self.pretrained = RobertaModel.from_pretrained('distilroberta-base')

        decoder = torch.nn.Linear(768, tokenizer.vocab_size)
        decoder.bias = torch.nn.Parameter(torch.zeros(tokenizer.vocab_size))

        self.fc = torch.nn.Sequential(
            torch.nn.Linear(768, 768),
            torch.nn.GELU(),
            torch.nn.LayerNorm(768, eps=1e-5),
            decoder,
        )

        # Copy the pretrained head parameters into the freshly built layers.
        parameters = AutoModelForCausalLM.from_pretrained('distilroberta-base')
        self.fc[0].load_state_dict(parameters.lm_head.dense.state_dict())
        self.fc[2].load_state_dict(parameters.lm_head.layer_norm.state_dict())
        self.fc[3].load_state_dict(parameters.lm_head.decoder.state_dict())

        # CrossEntropyLoss skips targets equal to -100 (its default
        # ignore_index), which is the filler value used in `labels`.
        self.criterion = torch.nn.CrossEntropyLoss()

    def forward(self, input_ids, attention_mask, labels=None):
        """Run the encoder and head; optionally compute the loss.

        Args:
            input_ids: token ids passed to the encoder.
            attention_mask: attention mask passed to the encoder.
            labels: optional target ids aligned position-by-position with
                ``input_ids``; when None no loss is computed.

        Returns:
            dict with 'loss' (None when ``labels`` is None) and 'logits'
            (the head output for every position).
        """
        hidden = self.pretrained(
            input_ids=input_ids,
            attention_mask=attention_mask,
        ).last_hidden_state
        logits = self.fc(hidden)

        loss = None
        if labels is not None:
            # Masked-word prediction: the label at position k belongs to the
            # logits at the SAME position k (the <mask> slot). The previous
            # causal-LM style shift (logits[:, :-1] vs labels[:, 1:]) scored
            # the label against the logits one position earlier, while the
            # accuracy checks read logits at the mask position itself.
            loss = self.criterion(
                logits.reshape(-1, logits.size(-1)),
                labels.reshape(-1),
            )

        return {'loss': loss, 'logits': logits}


model = Model()

# Count the parameters.
# The total is printed in units of 10,000 parameters.
print(sum(param.numel() for param in model.parameters()) / 10000)

# Trial forward pass on the batch held in `data`.
out = model(**data)

out['loss'], out['logits'].shape
测试
# Test
def test():
    """Evaluate the global model on up to 51 shuffled test batches.

    Prints the labels and predictions for every 10th batch, then the overall
    accuracy at the masked position (index 4), then the decoded sentences of
    the last batch together with their expected and predicted tokens.
    """
    model.eval()

    eval_loader = torch.utils.data.DataLoader(
        dataset=dataset['test'],
        batch_size=8,
        collate_fn=default_data_collator,
        shuffle=True,
        drop_last=True,
    )

    hits = 0
    seen = 0
    for step, batch in enumerate(eval_loader):
        # Set the expected token ids aside; they are needed for the accuracy.
        expected = batch['labels'][:, 4].clone()

        # Remove the labels from the batch so the model cannot cheat.
        batch['labels'] = None

        # Forward pass without gradient tracking.
        with torch.no_grad():
            result = model(**batch)

        # argmax over the vocabulary axis, then keep the masked position.
        predicted = result['logits'].argmax(dim=2)[:, 4]

        hits += (expected == predicted).sum().item()
        seen += 8

        if step % 10 == 0:
            print(step)
            print(expected)
            print(predicted)

        if step == 50:
            break

    print(hits / seen)

    # Show the sentences of the last batch with expected / predicted tokens.
    for row in range(8):
        print(tokenizer.decode(batch['input_ids'][row]))
        print(tokenizer.decode(expected[row]), tokenizer.decode(predicted[row]))


test()
0
tensor([ 47, 14838, 5392, 28, 80, 4839, 3668, 29])
tensor([ 47, 14633, 749, 28, 80, 4839, 3668, 2156])
10
tensor([ 101, 668, 16, 14, 352, 650, 3961, 16])
tensor([ 101, 773, 7897, 59, 2156, 7397, 3961, 16])
20
tensor([40485, 13, 29, 19303, 33, 16, 295, 9])
tensor([40485, 13, 4839, 16393, 33, 3391, 256, 9])
30
tensor([ 53, 33469, 3315, 3723, 7, 24473, 40776, 41])
tensor([11248, 15923, 3315, 3723, 7, 24473, 40776, 41])
40
tensor([ 2435, 5, 2046, 2084, 25210, 9, 42661, 7])
tensor([ 2343, 42, 4265, 8003, 33709, 7021, 9021, 6])
50
tensor([ 297, 22258, 998, 64, 10, 1499, 65, 2156])
tensor([ 457, 22258, 6545, 64, 10, 10416, 65, 33647])
0.32598039215686275
<s>a strong first<mask>, slightly less</s>
quarter half
<s>( villene<mask> ) seems to</s>
uve uve
<s>going to the<mask> may be just</s>
website gym
from transformers import AdamW
from transformers.optimization import get_scheduler#训练
def train():
    """Fine-tune the global model for one pass over `loader` and save it.

    Uses AdamW with a linear learning-rate schedule (no warm-up) and gradient
    clipping at norm 1.0. Every 50 steps it prints the step index, the loss,
    the accuracy at the masked position (index 4) for that batch, and the
    current learning rate. The whole model object is written to
    'models/2.预测中间词.model' at the end.
    """
    import os

    # torch.optim.AdamW replaces the deprecated transformers.AdamW; eps and
    # weight_decay are set to the deprecated class's defaults so the
    # optimizer configuration stays the same.
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=2e-5,
        eps=1e-6,
        weight_decay=0.0,
    )
    scheduler = get_scheduler(
        name='linear',
        num_warmup_steps=0,
        num_training_steps=len(loader),
        optimizer=optimizer,
    )

    model.train()
    for i, data in enumerate(loader):
        out = model(**data)
        loss = out['loss']

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        optimizer.step()
        scheduler.step()

        # The optimizer holds every model parameter, so this clears all
        # gradients; a separate model.zero_grad() is not needed.
        optimizer.zero_grad()

        if i % 50 == 0:
            label = data['labels'][:, 4]
            predicted = out['logits'].argmax(dim=2)[:, 4]
            correct = (label == predicted).sum().item()
            accuracy = correct / len(label)

            lr = optimizer.param_groups[0]['lr']

            print(i, loss.item(), accuracy, lr)

    # torch.save does not create missing directories.
    os.makedirs('models', exist_ok=True)
    torch.save(model, 'models/2.预测中间词.model')


train()
/root/anaconda3/envs/cpu/lib/python3.6/site-packages/transformers/optimization.py:309: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning
  FutureWarning,
0 18.949838638305664 0.0 1.9996385977593064e-05
50 4.755198001861572 0.625 1.9815684857246115e-05
100 5.0272216796875 0.25 1.963498373689917e-05
150 4.625316143035889 0.125 1.9454282616552225e-05
200 3.663780927658081 0.5 1.927358149620528e-05
250 2.5342917442321777 0.375 1.909288037585833e-05
300 4.986537933349609 0.375 1.8912179255511386e-05
350 3.403028964996338 0.625 1.873147813516444e-05
400 4.041268348693848 0.125 1.8550777014817495e-05
450 3.2715964317321777 0.5 1.8370075894470547e-05
500 2.6591811180114746 0.5 1.81893747741236e-05
550 4.937175750732422 0.25 1.8008673653776656e-05
600 4.845945835113525 0.25 1.7827972533429708e-05
650 1.8658218383789062 0.625 1.7647271413082763e-05
700 3.9473319053649902 0.25 1.7466570292735818e-05
750 2.065851926803589 0.625 1.728586917238887e-05
800 2.957096576690674 0.5 1.7105168052041924e-05
850 4.987250804901123 0.25 1.692446693169498e-05
900 3.5697021484375 0.5 1.674376581134803e-05
950 2.898092746734619 0.5 1.6563064691001085e-05
1000 4.39031457901001 0.375 1.638236357065414e-05
预测
# Rebind the global `model` to the object that train() wrote with
# torch.save(model, ...).
# NOTE(review): the file holds a whole pickled module rather than a
# state_dict, so presumably the Model class must already be defined here and
# the file must come from a trusted source — confirm before loading
# checkpoints from elsewhere.
model = torch.load('models/2.预测中间词.model')
# Run the evaluation again, this time against the reloaded model.
test()
2022-12-08