目录
1 处理数据
1.1 加载预训练的分词器¶
2 自定义创建数据集
2.1 创建dataset
2.2 自定义collate_fn(数据批量输出的方法)
2.3 创建数据加载器
3 创建模型
4 训练过程代码
5 保存训练好的模型
6 加载保存好的模型
7 测试预测阶段代码
#目前,NLP与CV主要使用transformers库
#框架:主要使用PyTorch
#NLP任务的大体流程:
#处理数据: 中文字符 ---> 数字
#创建数据集。 把处理好的数据变成PyTorch的数据集
#生成模型, 一般使用transformers库,不需要自己建模
#训练预测过程
#配置代理
# import os# os.environ['http_proxy'] = '127.0.0.1:10809'
# os.environ['https_proxy'] = '127.0.0.1:10809'
#这里是本地加载预训练模型,不需要
1 处理数据
1.1 加载预训练的分词器¶
from transformers import AutoTokenizer #AutoTokenizer分词器 可以使中文字符转变成数字#我这里是本地加载的模型文件
tokenizer = AutoTokenizer.from_pretrained('../data/model/gpt2-chinese-cluecorpussmall/')
print(tokenizer)
BertTokenizerFast(name_or_path='../data/model/gpt2-chinese-cluecorpussmall/', vocab_size=21128, model_max_length=1024, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}, clean_up_tokenization_spaces=True), added_tokens_decoder={0: AddedToken("[PAD]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),100: AddedToken("[UNK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),101: AddedToken("[CLS]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),102: AddedToken("[SEP]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),103: AddedToken("[MASK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True), }
#编码分词试算
text = [ '明朝驿使发,一夜絮征袍.素手抽针冷,那堪把剪刀.裁缝寄远道,几日到临洮.','长安一片月,万户捣衣声.秋风吹不尽,总是玉关情.何日平胡虏,良人罢远征.']
#输出结果为一个字典,包含'input_ids'、'token_type_ids'、'attention_mask'
tokenizer.batch_encode_plus(text)
{'input_ids': [[101, 3209, 3308, 7731, 886, 1355, 117, 671, 1915, 5185, 2519, 6151, 119, 5162, 2797, 2853, 7151, 1107, 117, 6929, 1838, 2828, 1198, 1143, 119, 6161, 5361, 2164, 6823, 6887, 117, 1126, 3189, 1168, 707, 3826, 119, 102], [101, 7270, 2128, 671, 4275, 3299, 117, 674, 2787, 2941, 6132, 1898, 119, 4904, 7599, 1430, 679, 2226, 117, 2600, 3221, 4373, 1068, 2658, 119, 862, 3189, 2398, 5529, 5989, 117, 5679, 782, 5387, 6823, 2519, 119, 102]], 'token_type_ids': [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]}
2 自定义创建数据集
2.1 创建dataset
import torchclass Dataset(torch.utils.data.Dataset):def __init__(self):super().__init__()#从本地读取数据with open('../data/datasets/chinese_poems.txt', encoding='utf-8') as f:lines = f.readlines() #读取的每一行数据都会以一个字符串的形式 依次添加到一个列表中#split()函数可以根据指定的分隔符将字符串拆分成多个子字符串,并将这些子字符串存储在一个列表中。#strip()函数默认移除字符串两端的空白字符(包括空格、制表符、换行符等)lines = [line.strip() for line in lines] #输出的lines是一个一维列表,里面的每一行诗都是一个字符串#self.的变量在类里面可以调用self.lines = lines #self.lines是一个列表,里面的元素都是一个个字符串def __len__(self):return len(self.lines)def __getitem__(self, i):"""可以向列表一样通过索引来获取数据"""return self.lines[i]#试跑一下
dataset = Dataset()
len(dataset), dataset[0]
(304752, '欲出未出光辣达,千山万山如火发.须臾走向天上来,逐却残星赶却月.')
dataset数据集只能一条一条数据的输出,不能一批批数据传输,
需要将datatset变成pytorch中dataloader的数据形式,将数据可以批量输出
2.2 自定义collate_fn(数据批量输出的方法)
def collate_fn(batch):#使用分词器 把中文编码成数字#tokenizer分词器的输出结果data是一个字典,包含'input_ids'、'token_type_ids'、'attention_mask'data = tokenizer.batch_encode_plus(batch, padding=True,truncation=True,max_length=512,return_tensors='pt')#向字典data中添加数据标签目标值labels, 用data原数据中的['input_ids']诗句文字编码来赋值,#克隆一份对原数据无影响data['labels'] = data['input_ids'].clone()return data
2.3 创建数据加载器
loader = torch.utils.data.DataLoader(dataset=dataset,batch_size=4, collate_fn=collate_fn,shuffle=True,drop_last=True)
#dataloader不能直接访问数据,需要for循环来获取数据
#查看第一批数据
for i, data in enumerate(loader):break #只循环一次
i
0
data #data是一个字典, 包含'input_ids'、'token_type_ids'、'attention_mask'、'labels'
{'input_ids': tensor([[ 101, 2708, 4324, 2406, 782, 1777, 1905, 3918, 117, 7345, 5125, 7346,7790, 6387, 4685, 2192, 119, 738, 4761, 5632, 3300, 1921, 1045, 1762,117, 6475, 955, 865, 6778, 4212, 2769, 1412, 119, 102, 0, 0,0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,0, 0, 0, 0, 0, 0],[ 101, 1921, 4495, 671, 4954, 117, 5966, 1434, 3369, 7755, 119, 7755,3323, 2768, 1759, 117, 1759, 5543, 4495, 4289, 119, 5310, 702, 5872,5701, 117, 2899, 6627, 2336, 1880, 119, 3719, 5564, 6762, 1726, 117,6631, 676, 686, 867, 119, 102, 0, 0, 0, 0, 0, 0,0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,0, 0, 0, 0, 0, 0],[ 101, 753, 2399, 3736, 677, 6224, 3217, 2495, 117, 7564, 2682, 7028,3341, 2769, 3313, 1726, 119, 3922, 6862, 686, 7313, 6443, 3160, 2533,117, 4856, 2418, 4685, 6878, 684, 6124, 3344, 119, 102, 0, 0,0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,0, 0, 0, 0, 0, 0],[ 101, 3926, 7599, 711, 2769, 6843, 2495, 5670, 117, 3144, 5108, 7471,3351, 6629, 5946, 4170, 119, 2359, 2512, 2661, 7607, 4904, 3717, 100,117, 3587, 1898, 3009, 3171, 1911, 7345, 6068, 119, 1126, 782, 2157,1762, 3983, 1928, 2279, 117, 671, 4275, 756, 4495, 3717, 2419, 1921,119, 4007, 4706, 5679, 3301, 3187, 1962, 6983, 117, 3634, 2552, 2347,2899, 736, 3736, 6804, 119, 102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0,0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]), 'labels': tensor([[ 101, 2708, 4324, 2406, 782, 1777, 1905, 3918, 117, 7345, 5125, 7346,7790, 6387, 4685, 2192, 119, 738, 4761, 5632, 3300, 1921, 1045, 1762,117, 6475, 955, 865, 6778, 4212, 2769, 1412, 119, 102, 0, 0,0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,0, 0, 0, 0, 0, 0],[ 101, 1921, 4495, 671, 4954, 117, 5966, 1434, 3369, 7755, 119, 7755,3323, 2768, 1759, 117, 1759, 5543, 4495, 4289, 119, 5310, 702, 5872,5701, 117, 2899, 6627, 2336, 1880, 119, 3719, 5564, 6762, 1726, 117,6631, 676, 686, 867, 119, 102, 0, 0, 0, 0, 0, 0,0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,0, 0, 0, 0, 0, 0],[ 101, 753, 2399, 3736, 677, 6224, 3217, 2495, 117, 7564, 2682, 7028,3341, 2769, 3313, 1726, 119, 3922, 6862, 686, 7313, 6443, 3160, 2533,117, 4856, 2418, 4685, 6878, 684, 6124, 3344, 119, 102, 0, 0,0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,0, 0, 0, 0, 0, 0],[ 101, 3926, 7599, 711, 2769, 6843, 2495, 5670, 117, 3144, 5108, 7471,3351, 6629, 5946, 4170, 119, 2359, 2512, 2661, 7607, 4904, 3717, 100,117, 3587, 1898, 3009, 3171, 1911, 7345, 6068, 119, 1126, 782, 2157,1762, 3983, 1928, 2279, 117, 671, 4275, 756, 4495, 3717, 2419, 1921,119, 4007, 4706, 5679, 3301, 3187, 1962, 6983, 117, 3634, 2552, 2347,2899, 736, 3736, 6804, 119, 102]])}
3 创建模型
#LM:语言模型
#AutoModelForCausalLM 语言模型的加载器
# from transformers import AutoModelForCausalLM, GPT2Model
from transformers import AutoModelForCausalLM
#加载模型
model = AutoModelForCausalLM.from_pretrained('../data/model/gpt2-chinese-cluecorpussmall/')#查看加载的预训练模型的参数量
print(sum(p.numel() for p in model.parameters()))
102068736
#试算预测一下
with torch.no_grad(): #模型预测时,参数不需要梯度下降#outs是一个元组,包含'loss'(损失)和'logits'(概率)outs = model(**data) outs['logits'].shape
#4:batch_size
#197:每个句子的序列长度
#21128:每个词对应的21128(vocab_size)个词概率
torch.Size([4, 66, 21128])
outs['loss'], outs['logits']
(tensor(8.5514),tensor([[[ -9.9143, -9.7647, -9.8217, ..., -9.6961, -9.7799, -9.6771],[ -7.4731, -8.7423, -8.4802, ..., -8.2767, -8.6411, -9.1488],[ -8.7324, -9.3639, -9.3685, ..., -9.7467, -9.2594, -9.9237],...,[ -3.6951, -3.9939, -4.2000, ..., -4.2021, -4.6660, -4.4627],[ -3.7271, -4.0562, -4.2753, ..., -4.2301, -4.7670, -4.5282],[ -3.6152, -3.9949, -4.1994, ..., -4.1643, -4.6812, -4.4797]],[[ -9.9143, -9.7647, -9.8217, ..., -9.6961, -9.7799, -9.6771],[ -8.5889, -9.2279, -9.2168, ..., -8.6957, -8.1567, -8.5526],[ -8.8908, -8.8825, -8.7488, ..., -9.8976, -9.4964, -10.1446],...,[ -3.8280, -3.7346, -4.4447, ..., -3.8380, -4.3585, -4.2275],[ -4.0099, -3.8985, -4.6581, ..., -3.9868, -4.5486, -4.3698],[ -3.8161, -3.7165, -4.4473, ..., -3.8579, -4.3969, -4.2764]],[[ -9.9143, -9.7647, -9.8217, ..., -9.6961, -9.7799, -9.6771],[ -7.7595, -8.7731, -8.8029, ..., -9.2167, -8.4741, -8.4485],[ -9.1754, -8.8637, -9.1363, ..., -8.7321, -8.7189, -8.9582],...,[ -3.7426, -4.1014, -4.2192, ..., -4.3925, -4.5313, -4.6184],[ -3.8279, -4.2058, -4.3173, ..., -4.4447, -4.6614, -4.6665],[ -3.7733, -4.1448, -4.2570, ..., -4.4249, -4.6132, -4.6397]],[[ -9.9143, -9.7647, -9.8217, ..., -9.6961, -9.7799, -9.6771],[ -6.8225, -7.6599, -7.4913, ..., -7.5897, -7.4440, -7.5681],[ -7.3068, -7.6038, -7.2369, ..., -7.8313, -8.0071, -7.8388],...,[ -5.6309, -5.6956, -5.5271, ..., -5.4339, -5.3443, -5.6756],[ -6.4130, -6.3038, -6.4816, ..., -6.5781, -6.2063, -6.4139],[ -3.6458, -4.0801, -3.7062, ..., -4.2418, -3.9411, -4.0311]]]))
4 训练过程代码
from transformers import AdamW
from transformers.optimization import get_scheduler #学习率的衰减策略#训练
def train():#model是在此函数外部创建的,在此函数内调用前,需要声明model是全局变量global model#设置设备device = 'cuda:0' if torch.cuda.is_available() else 'cpu'#将模型传到设备上model = model.to(device)#创建梯度下降的优化器optimizer = AdamW(model.parameters(), lr=5e-5) #lr=0.00005, -5表示有5位小数#创建学习率衰减计划scheduler = get_scheduler(name='linear', #线性的num_warmup_steps=0, #学习率从一开始就开始衰减,没有预热缓冲期num_training_steps=len(loader), #loader中有多少批数据就训练多少次optimizer=optimizer)model.train()for i, data in enumerate(loader):for k in data.key():#将字典data中每个key所对应的value都传到设备上,再赋值给data[k],相当于把data传到了设备上data[k] = data[k].to(device)#将设备上的data传入模型中,获取输出结果outs(一个字典,包含loss和logits(概率分布))outs = model(**data) #data是一个字典, **data将字典解包成关键字参数传入#从outs中获取损失,在训练过程中观察loss是不是在下降,不下降就是不正常loss = outs['loss']#反向传播loss.backward()#为了梯度下降的稳定性,防止梯度太大,进行梯度裁剪torch.nn.utils.clip_grad_norm_(model.parameters, 1.0) #公式中的c最大值就是1#梯度更新optimizer.step()scheduler.step()#梯度清零optimizer.zero_grad()model.zero_grad()if i % 1000 == 0: #每1000个步数,就输出打印内容#下一句诗句是上一句的预测目标真实值,有一个偏移量labels = data['labels'][:, 1:]#预测值outs = outs['logits'].argmax(dim=2)[:, :-1]#筛选条件select = labels != 0 #0是补得pad没有意义,需要筛掉#分别对labels和outs进行筛选labels = labels[select]outs = outs[select]del select #后面这个变量没有用了, 删除防止占用过多内存#计算准确率#labels.numel() 求labels内元素的总个数#.item() 在pytorch中,取出tensor标量的数值cccuracy = (labels == outs).sum().item() / labels.numel() #取出学习率lr = optimizer.state_dict()['param_groups'][0]['lr']print(i, loss.item(), lr, accuracy)train()
5 保存训练好的模型
#保存训练好的模型
# model = model.to('cpu') #将模型传到设备上
# torch.save(model, 'model.pt')
6 加载保存好的模型
# model_2 = torch.laod('../data/model/AI-Poem-save.model')
7 测试预测阶段代码
def generate(text, row, col, model):"""text:传入的数据row, col:预测的诗句是几行几列的model:使用的是哪个模型来预测"""def generate_loop(data):"""循环来预测"""#模型预测时,不需要求导来反向传播with torch.no_grad():outs = model(**data)#从outs中获取分类概率, 输出形状与输入形状一致,所以batch_size在后面# outs形状 [5(五言诗,序列长度), batch_size, vocab_size]outs = outs['logits']#outs形状 [5(五言诗,序列长度), vocab_size]#只取一个元素会把对应的维度降调outs = outs[:, -1] #最后一个是预测值#写诗:预测概率最高的词不一定是最合适的#取出概率较高的前50个#[5, vocab_size] --> [5, 50]topk_value = torch.topk(outs, 50).values #按从小到大排序的#取最后一个就是概率最大的那一个#[5, 50] --> [5] ,升维--> [5, 1]topk_value = topk_value[:, -1].unsqueeze(dim=1)#赋值 # -float('inf')负无穷大 ,表示没有意义outs = outs.masked_fill(outs < topk_value, -float('inf')) #不允许写特殊字符, 将其赋值为负无穷大outs[:, tokenizer.sep_token_id] = -float('inf') #分隔符outs[:, tokenizer.unk_token_id] = -float('inf') #未知字符outs[:, tokenizer.pad_token_id] = -float('inf') #填充padfor i in ',。':outs[:, tokenizer.get_vocab()[i]] = -float('inf')#根据概率做一个无放回的采样:不会出现重复的数据#[5, vocab_size] ---> [5, 1]outs = outs.softmax(dim=1)outs = outs.multinomial(num_sample=1) #从中取一个#强制添加标点c = data['input_ids'].shape[1] / (col + 1)#若c为整数if c % 1 == 0:#若为偶数行if c % 2 == 0:outs[:, 0] = tokenizer.get_vocab()['。']else:outs[:, 0] = tokenizer.get_vocab()[',']#将原始的输入数据和预测的结果拼到一起, 当做下一次预测的输入, 依次循环data['input_ids'] = torch.cat([data['input_ids'], outs], dim=1)data['attention_mask'] = torch.ones_like(data['input_ids'])data['token_type_ids'] = torch.zeros_like(data['input_ids'])data['labels'] = data['input_ids'].clone()# row * col + 1 : 总字数+标点符号if data['input_ids'].shape[1] >= row * col + 1:return datareturn generate_loop(data)#重复三遍:一次生成三首,一次生成的效果可能不太好data = tokenizer.batch_encode_plus([text]*3, return_tensors='pt')data['input_ids'] = data['input_ids'][:, :-1] #最后一个不要data['attention_mask'] = torch.ones_like(data['input_ids'])data['token_type_ids'] = torch.zeros_like(data['input_ids'])data['labels'] = data['input_ids'].clone()data = generate_loop(data)for i in range(3):#一次生成三首,按索引打印输出其中一首print(i, tokenizer.decode(data['input_ids'][i]))
model_2 = torch.load('../data//model/AI-Poem-save.model')generate('秋高气爽', row=4, col=7, model=model_2)
0 [CLS] 秋 高 气 爽 雁 初 飞 , 云 树 高 峰 落 叶 稀 。 人 尽 夜 归 山 外 宿 , 鸡 鸣 霜 月 下 寒 衣 。 1 [CLS] 秋 高 气 爽 木 生 秋 , 何 处 仙 方 未 可 求 。 莫 遣 夜 猿 催 老 去 , 东 风 吹 老 上 林 丘 。 2 [CLS] 秋 高 气 爽 早 蝉 喧 , 清 籁 无 声 响 自 喧 。 野 望 岂 容 云 梦 见 , 江 涵 应 属 月 华 昏 。