大模型开发(三):全量微调项目——基于GPT2 搭建医疗问诊机器人

news/2025/2/27 15:38:33/

全量微调项目——基于GPT2 搭建医疗问诊机器人

  • 0 前言
  • 1 全量微调及项目介绍
    • 1.1 全量微调简介
    • 1.2 项目介绍
    • 1.3 数据介绍
    • 1.4 GPT2模型与硬件配置
  • 2 数据与模型准备
    • 2.1 数据准备
    • 2.2 模型准备
    • 2.3 数据预处理
  • 3 数据集类及其导入器
    • 3.1 dataset.py
    • 3.2 dataloader.py
  • 4 模型配置与推理
    • 4.1 配置文件
    • 4.2 模型推理
  • 5 模型训练
    • 5.1 损失函数
    • 5.2 精度计算
    • 5.3 模型训练
    • 5.4 最终的项目结构
    • 5.5 训练后的推理

0 前言

上一篇文章讲到,大模型都是基于过去的经验数据进行训练完成,它没有学过企业私有的知识,为了处理私有知识,一般可以使用私有知识对模型进行微调,也可以建立本地知识库,然后利用RAG技术实现。
什么时候用微调,什么时候用RAG,有以下几条标准:

1.如果企业里有算力,私有数据量较大,那优先可以微调,时间成本要高;
2.如果没有高的算力,或者数据量小,可以使用RAG;
3.另外如果算力充足,数据量也大,可以实现RAG和微调结合。

总体来讲,RAG技术比较成熟,也比较容易实现,但效果不如微调。

1 全量微调及项目介绍

1.1 全量微调简介

模型的微调有多种,例如全量微调、部分参数微调、参数高效微调、提示词微调、知识蒸馏等,本文介绍全量微调。全量微调(Full Fine-tuning)就是对整个预训练模型的所有参数进行微调,常用于文本生成任务。

1.2 项目介绍

本项目的目标是搭建一个对话机器人,这个机器人使用医疗问诊数据进行微调,使其能实现自动问诊,成为AI医生。
在这里插入图片描述
在这里插入图片描述

1.3 数据介绍

这里用于微调的数据都是一些问诊信息,分别存在于medical_train.txtmedical_valid.txt两个文件中,其中medical_train.txt内容如下:
在这里插入图片描述
medical_train.txt有9万多行,按一个对话回合有三行来算,这里共有三万多个样本。
在这里插入图片描述
medical_valid.txt有1200多行,因此有400个样本。

1.4 GPT2模型与硬件配置

由于计算资源有限,我们这里用GPT2来演示,实际工作中,需要根据算力和需求来选。GPT-2是OpenAI 在2019 年推出的第二代生成式预训练模型,参数量是15亿,权重文件只有三百多兆。

硬件我们使用FunHPC云算力市场上的RTX 3080显卡,显存为12G,关于FunHPC云算力的使用,可以参考这篇文章

2 数据与模型准备

2.1 数据准备

创建一个名为data的文件夹,然后把medical_train.txtmedical_valid.txt放进去,如下图所示:
在这里插入图片描述
这里面upload-data/data包含了我们上传的GPT2模型的相关代码和权重文件。

2.2 模型准备

我们在当前目录下新建一个名为model的目录,把upload-data/data下的gpt2、config、vocab三个文件夹复制到model目录下。
在这里插入图片描述
vocab目录下,包含了两个词表文件,分别是vocab.txt和vocab2.txt,它们分别包含的字符数量为13317和21128。而config目录则包含了一个模型配置文件,名为config.json,内容如下:

{"activation_function": "gelu_new","architectures": ["GPT2LMHeadModel"],"attn_pdrop": 0.1,"bos_token_id": 50256,"embd_pdrop": 0.1,"eos_token_id": 50256,"gradient_checkpointing": false,"initializer_range": 0.02,"layer_norm_epsilon": 1e-05,"model_type": "gpt2","n_ctx": 1024,"n_embd": 768,"n_head": 12,"n_inner": null,"n_layer": 12,"n_positions": 1024,"output_past": true,"resid_pdrop": 0.1,"summary_activation": null,"summary_first_dropout": 0.1,"summary_proj_to_labels": true,"summary_type": "cls_index","summary_use_proj": true,"task_specific_params": {"text-generation": {"do_sample": true,"max_length": 400}},"tokenizer_class": "BertTokenizer","transformers_version": "4.2.0","use_cache": true,"vocab_size": 13317
}

gpt2目录下,则是关于模型的说明文件,这个会早构造模型的时候使用。

2.3 数据预处理

在当前目录下,新建一个名为data_preprocess的python包,内部包含的python脚本如下图所示:
在这里插入图片描述
这里先介绍一下preprocess.py,剩下两个脚本稍后介绍。这个脚本是数据处理的,它将中文句子分词(字),然后再对每个字去词典里查id,最后将每个样本的id保存到pkl文件中,内容如下:

from transformers import BertTokenizerFast # 分词工具
import pickle
from tqdm import tqdm
import osdef data_preprocess(train_txt_path, train_pkl_path):"""对原始语料进行tokenize,将每段对话处理成如下形式:"[CLS]utterance1[SEP]utterance2[SEP]utterance3[SEP]""""# 初始化tokenizer(分词器),使用BertTokenizerFast从预训练的中文Bert模型(bert-base-chinese)创建一个tokenizer对象tokenizer = BertTokenizerFast('../model/vocab/vocab.txt',sep_token="[SEP]",    # 分割符pad_token="[PAD]",    # 填充符cls_token="[CLS]",    # 起始符)# 打印词标长度print(f'tokenizer.vocab_size-->{tokenizer.vocab_size}')# 获取ID(即在此表中的索引,例如在词表的第一行,那么id就是0,在词表的第二行,那么id就是1)sep_id = tokenizer.sep_token_id  # 获取分隔符[SEP]的token IDcls_id = tokenizer.cls_token_id  # 获取起始符[CLS]的token IDprint(f'sep_id-->{sep_id}')print(f'cls_id-->{cls_id}')# 读取训练数据集with open(train_txt_path, 'rb') as f:data = f.read().decode("utf-8")  # 以UTF-8编码读取文件内容# 根据换行符区分不同的对话段落(样本之间有两个换行符),需要区分Windows和Linux\mac环境下的换行符if "\r\n" in data:train_data = data.split("\r\n\r\n")     # Windows下换行为\r\n,连续两个换行符用来分割数据else:train_data = data.split("\n\n")# 打印对话段落数量(训练集样本数)print(len(train_data))  # 开始进行tokenize# 保存所有的对话数据,每条数据的格式为:"[CLS]seq1[SEP]seq2[SEP]seq3[SEP]"dialogue_len = []  # 记录所有对话tokenize分词之后的长度,用于统计中位数与均值dialogue_list = []  # 记录所有对话(将每条句子中,每个字符的id组成一个列表,for index, dialogue in enumerate(tqdm(train_data)):# 用换行符来分割问诊内容与回答if "\r\n" in dialogue:sequences = dialogue.split("\r\n")else:sequences = dialogue.split("\n")# 创建一个列表来保存对话内容的id,将起始符的id加入到列表中input_ids = [cls_id]  # 每个dialogue以[CLS]seq1[sep]seq2[sep],因此以[CLS]对应的id开头# 分词器分词的结果,是不带起始符、分割符和填充符的,所以这里需要提前加起始符# 分隔符是在随后的循环里加,填充符是在数据导入器的collate_fn函数中加# 分词(其实这里是分字),并将字符索引加入到列表中for sequence in sequences:# 对问诊/回复内容进行tokenize,并将结果拼接到到input_ids列表中input_ids += tokenizer.encode(sequence, add_special_tokens=False)input_ids.append(sep_id)  # 每个seq之后添加[SEP],表示这条句子结束dialogue_len.append(len(input_ids))  # 将对话的tokenize后的长度添加到对话长度列表中dialogue_list.append(input_ids)  # 将tokenize后的对话添加到对话列表中# #print(f'dialogue_len--->{dialogue_len}')  # 打印对话长度列表print(f'dialogue_list--->{dialogue_list[:2]}')  # 打印前两个样本(对话)的id# 保存数据with open(train_pkl_path, "wb") as f:pickle.dump(dialogue_list, f)if __name__ == '__main__':valid_txt_path = '../data/medical_valid.txt'valid_pkl_path = '../data/medical_valid.pkl'data_preprocess(valid_txt_path, valid_pkl_path)

虽然这里分词器不会给句子自动添加起始符、分隔符和填充符,但我们可以通过分词器拿到这些符号的id。

输出

tokenizer.vocab_size-->13317
sep_id-->102
cls_id-->101
413
100%|███████████████████████████████████████████████████████████████████| 413/413 [00:00<00:00, 2440.13it/s]
dialogue_len--->[205, 34, 123, 20, 26, 23, 287, 26, 24, 28, 226, 20, 292, 100, 136, 106, 72, 22, 277, 239, 23, 32, 233, 96, 38, 204, 221, 285, 77, 211, 158, 134, 263, 21, 49, 108, 255, 44, 25, 253, 23, 21, 36, 22, 22, 288, 278, 230, 48, 247, 191, 126, 26, 233, 231, 21, 179, 213, 29, 20, 22, 29, 29, 37, 138, 54, 121, 25, 185, 21, 28, 223, 23, 35, 184, 25, 24, 103, 66, 195, 244, 182, 175, 254, 254, 132, 184, 277, 92, 19, 284, 23, 71, 179, 130, 261, 77, 24, 28, 220, 30, 214, 28, 285, 18, 59, 28, 209, 19, 218, 29, 24, 31, 121, 99, 26, 208, 97, 236, 27, 28, 279, 158, 68, 25, 34, 19, 130, 95, 28, 20, 63, 151, 24, 196, 185, 27, 19, 29, 19, 33, 242, 204, 188, 51, 82, 23, 20, 198, 52, 227, 23, 28, 22, 31, 25, 247, 26, 58, 30, 28, 213, 201, 24, 24, 255, 24, 23, 35, 19, 26, 233, 300, 34, 48, 58, 30, 25, 30, 77, 189, 213, 32, 32, 163, 226, 33, 39, 224, 24, 58, 59, 24, 267, 270, 275, 27, 17, 193, 29, 23, 178, 138, 28, 217, 20, 222, 91, 35, 258, 97, 26, 227, 184, 117, 286, 37, 115, 152, 32, 31, 116, 27, 77, 26, 210, 30, 27, 121, 284, 87, 26, 179, 135, 91, 34, 19, 299, 32, 195, 46, 214, 198, 47, 66, 35, 232, 167, 174, 70, 25, 194, 107, 56, 261, 113, 252, 232, 21, 24, 25, 240, 36, 25, 26, 99, 205, 71, 55, 84, 41, 223, 241, 36, 198, 23, 185, 225, 170, 113, 25, 226, 24, 38, 196, 253, 69, 63, 69, 26, 181, 20, 35, 25, 41, 225, 19, 234, 22, 29, 278, 272, 42, 96, 212, 90, 234, 269, 183, 220, 226, 29, 181, 20, 17, 24, 25, 22, 26, 252, 28, 22, 19, 235, 26, 21, 270, 21, 289, 59, 27, 78, 203, 55, 185, 27, 237, 21, 22, 225, 29, 75, 61, 71, 20, 24, 26, 225, 27, 226, 27, 289, 22, 21, 283, 262, 22, 41, 19, 192, 34, 284, 200, 26, 32, 120, 39, 22, 22, 255, 23, 69, 282, 56, 22, 230, 219, 136, 249, 31, 58, 37, 146, 220, 35, 21, 252, 214, 232, 231, 153, 220, 35, 27, 40, 68, 30, 36, 236, 232, 70, 251, 23, 151, 35, 21, 19, 251, 218, 289, 22, 37, 2]
dialogue_list--->[[101, 2207, 2111, 758, 2259, 8024, 5553, 1453, 3900, 2349, 5310, 5514, 1920, 8024, 2207, 2111, 758, 2259, 8024, 5553, 1453, 3900, 2349, 5310, 5514, 1920, 8024, 6783, 3890, 1400, 679, 4578, 738, 679, 4173, 749, 8024, 2218, 3221, 3900, 2349, 1920, 749, 8024, 1333, 3341, 130, 119, 12064, 115, 124, 119, 9695, 8175, 4385, 1762, 8108, 8278, 115, 9394, 1348, 1391, 749, 1288, 3299, 1046, 2861, 7450, 5162, 1469, 4347, 1928, 5826, 2990, 1357, 4289, 7578, 5108, 3389, 749, 671, 678, 8024, 1359, 2768, 8108, 8278, 115, 12485, 8175, 3291, 1920, 749, 8024, 2582, 720, 1215, 1435, 8043, 6435, 1278, 4495, 2376, 1221, 8024, 6468, 6468, 102, 872, 1962, 8024, 5440, 5991, 5499, 5143, 5606, 3900, 2349, 5310, 4142, 8024, 3315, 4567, 2382, 680, 677, 1461, 1429, 6887, 2697, 3381, 3300, 5468, 5143, 511, 707, 2414, 6134, 4385, 711, 1355, 4178, 510, 5592, 4578, 510, 1445, 1402, 8024, 2772, 1355, 4495, 5592, 3811, 2772, 912, 4908, 511, 5592, 4578, 3300, 3198, 6496, 5319, 4578, 1762, 1381, 678, 5592, 6956, 8024, 738, 1377, 1762, 1071, 800, 6956, 855, 8024, 3315, 4567, 1914, 2247, 4567, 3681, 2697, 3381, 8024, 671, 5663, 5632, 4197, 4571, 2689, 8024, 1350, 3198, 4638, 2190, 4568, 3780, 4545, 1315, 1377, 102], [101, 3171, 2900, 1086, 3490, 6117, 5052, 1314, 6496, 4638, 2797, 3318, 3780, 4545, 3300, 763, 784, 720, 8043, 102, 3171, 2900, 1086, 3490, 8039, 7474, 5549, 4649, 4480, 3952, 4895, 4919, 3490, 102]]

程序运行后,data目录下将多出一个pkl文件:
在这里插入图片描述
类似的,可以对训练集做相同的处理:

if __name__ == '__main__':train_txt_path = '../data/medical_train.txt'train_pkl_path = '../data/medical_train.pkl'data_preprocess(train_txt_path, train_pkl_path)

3 数据集类及其导入器

我们接着来介绍data目录下的dataset.pydataloader.py,这里需要对PyTorch有基本的了解。

3.1 dataset.py

这里就是自己写一个数据集类,然后继承torch.utils.data.Dataset

# -*- coding: utf-8 -*-
import torch
from torch.utils.data import Dataset
import pickleclass MyDataset(Dataset):def __init__(self, input_list, max_len):super().__init__()"""初始化函数,用于设置数据集的属性:param input_list: 输入列表,包含所有对话的tokenize后的输入序列:param max_len: 最大序列长度,用于对输入进行截断或填充"""# print(f'input_list--->{len(input_list)}')self.input_list = input_list  # 将输入列表赋值给数据集的input_list属性self.max_len = max_len  # 将最大序列长度赋值给数据集的max_len属性def __len__(self):return len(self.input_list)def __getitem__(self, index):"""根据给定索引获取数据集中的一个样本:param index: 样本的索引:return: 样本的输入序列张量"""input_ids = self.input_list[index]  # 获取给定索引处的输入序列input_ids = input_ids[:self.max_len]  # 根据最大序列长度对输入进行截断input_ids = torch.tensor(input_ids, dtype=torch.long)  # 将输入序列转换为张量long类型return input_ids  # 返回样本的输入序列张量if __name__ == '__main__':with open('../data/medical_train.pkl', "rb") as f:train_input_list = pickle.load(f)  # 从文件中加载输入列# print(f'train_input_list-->{len(train_input_list)}')# print(f'train_input_list-->{type(train_input_list)}')mydataset = MyDataset(input_list=train_input_list, max_len=300)print(f'mydataset-->{len(mydataset)}')result = mydataset[3]print(result)

输出

mydataset-->30177
tensor([ 101, 7028, 1908, 5524, 5522,  977, 3632, 1355, 5509, 4638, 7770, 1314,1728, 5162, 3300,  763,  784,  720, 8043,  102, 7942,  860, 1216, 5543,679, 6639, 8039, 7770, 7977, 2097, 1967,  102])

3.2 dataloader.py

这个脚本里面有三个函数,load_dataset用于导入数据集,collate_fn用于数据对齐,get_dataloader用于获取数据导入器,学过PyTorch的话,看懂很容易。代码内容如下:

# -*- coding: utf-8 -*-
import torch.nn.utils.rnn as rnn_utils  # 导入rnn_utils模块,用于处理可变长度序列的填充和排序
from torch.utils.data import Dataset, DataLoader
import torch
import pickle
from dataset import MyDataset   # 导入自定义的数据集类is_print = False
def load_dataset(train_path, valid_path):with open(train_path, "rb") as f:train_input_list = pickle.load(f)  # 从文件中加载输入列表with open(valid_path, "rb") as f:valid_input_list = pickle.load(f)  # 从文件中加载输入列表train_dataset = MyDataset(train_input_list, 300)  # 创建训练数据集对象val_dataset = MyDataset(valid_input_list, 300)  # 创建验证数据集对象return train_dataset, val_dataset  # 返回训练数据集和验证数据集def collate_fn(batch):"""自定义的collate_fn函数,用于将数据集中的样本进行批处理:param batch: 样本列表:return: 经过填充的输入序列张量和标签序列张量"""if is_print:print(f'当前batch中,最长的句子长度为:', max([sequence.shape for sequence in batch]))print('处理前')print(f'batch的第一个样本的长度--》{batch[0].shape}')print(f'batch的第二个样本的长度--》{batch[1].shape}')print(f'*'*80)#rnn_utils.pad_sequence:将根据一个batch中,最大句子长度,进行补齐# 对输入序列进行填充(填充0),使其长度一致input_ids = rnn_utils.pad_sequence(batch, batch_first=True, padding_value=0)  if is_print:print('处理后')print(f'batch的第一个样本的长度--》{input_ids[0].shape}')print(f'batch的第二个样本的长度--》{input_ids[1].shape}')print(f'*'*80)# 对标签序列进行填充,使其长度一致,补充的位置不去计算损失,-100作为“不计算损失”的标志labels = rnn_utils.pad_sequence(batch, batch_first=True, padding_value=-100)  return input_ids, labels  # 返回经过填充的输入序列张量和标签序列张量def get_dataloader(train_path, valid_path, batch_size):"""获取训练数据集和验证数据集的DataLoader对象:return: 训练数据集的DataLoader对象和验证数据集的DataLoader对象"""# 加载训练数据集和验证数据集train_dataset, val_dataset = load_dataset(train_path, valid_path)  train_dataloader = DataLoader(train_dataset,batch_size=batch_size,shuffle=True,collate_fn=collate_fn,drop_last=True)  # 创建训练数据集的DataLoader对象validate_dataloader = DataLoader(val_dataset,batch_size=batch_size,shuffle=True,collate_fn=collate_fn,drop_last=True)  # 创建验证数据集的DataLoader对象return train_dataloader, validate_dataloader  # 返回训练数据集的DataLoader对象和验证数据集的DataLoader对象if __name__ == '__main__':train_path = '../data/medical_train.pkl'valid_path = '../data/medical_valid.pkl'train_dataloader, validate_dataloader = get_dataloader(train_path, valid_path, batch_size=8)is_print = True  # 打印collate_fn函数中的信息for input_ids, labels in train_dataloader:print(f'input_ids--->{input_ids.shape}')print(f'labels--->{labels.shape}')break

输出

当前batch中,最长的句子长度为: torch.Size([207])
处理前
batch的第一个样本的长度--》torch.Size([26])
batch的第二个样本的长度--》torch.Size([173])
********************************************************************************
处理后
batch的第一个样本的长度--》torch.Size([207])
batch的第二个样本的长度--》torch.Size([207])
********************************************************************************
input_ids--->torch.Size([8, 207])
labels--->torch.Size([8, 207])

4 模型配置与推理

4.1 配置文件

我们在当前目录下新建一个名为parameter_config.py的文件,用于设置相关参数,内容如下:

#-*- coding: utf-8 -*-
import os
import torchclass ParameterConfig():def __init__(self):# 判断是否使用GPUself.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')# 词典路径:在vocab文件夹里面self.vocab_path = './model/vocab/vocab.txt'# 训练文件路径self.train_path = 'data/medical_train.pkl'# 验证数据文件路径self.valid_path = 'data/medical_valid.pkl'# 模型配置文件self.config_json = './model/config/config.json'# 模型保存路径self.save_model_path = 'save_model'# 如果你有预训练模型就写上路径(我们本次没有直接运用GPT2它预训练好的模型,而是仅只用了该模型的框架)self.pretrained_model = ''# 要忽略的字符索引,因为有些字符需要补齐长度,补的时候用-100来填充,填充的部分不计算损失函数,也不计算精度self.ignore_index = -100# 历史对话的长度(即问+答的数量,而非“问答对”的数量)self.max_history_len = 3# "dialogue history的最大长度"# 每一个完整对话的句子最大长度self.max_len = 300  # 每个utterance的最大长度,超过指定长度则进行截断# 重复惩罚参数,若生成的对话重复性较高,可适当提高该参数self.repetition_penalty = 10.0 # top-k取词策略中的kself.topk = 4# 训练参数self.batch_size = 8self.epochs = 100self.lr = 2.6e-5# AdamW优化器的eps,在计算梯度的时候,为了增加数值计算的稳定性而加到分母里的项,其为了防止在实现中除以零self.eps = 1.0e-09# 梯度上限,用于进行梯度裁剪,防止梯度爆炸self.max_grad_norm = 2.0# 多少步打印一次loss,这里的“步”指的是反向传播的次数self.loss_step = 1 # 多少步更新一次参数,这里的“步”指的是反向传播的次数self.gradient_accumulation_steps = 4	# warmup达到最大学习率的步数self.warmup_steps = 100 # 使用Warmup预热学习率的方式,即先用最初的小学习率训练,然后每个step增大一点点,直到达到最初设置的比较大的学习率时(注:此时预热学习率完成),采用最初设置的学习率进行训练,有助于使模型收敛速度变快,效果更佳。if __name__ == '__main__':pc = ParameterConfig()print(pc.train_path)print(pc.device)print(torch.cuda.device_count())

输出:

data/medical_train.pkl
cuda
1

这里有一些参数暂时没有理解没关系,看了后面的代码就能理解了。

4.2 模型推理

在当前目录下新建一个名为inference.py的文件,用于模型推理,下面的代码有点复杂,我已通过注释尽量降低阅读的难度,这部分代码还是比较重要的,这里如果看懂了,那么大模型开发中的 “输入问题处理” 和 “推理结果后处理” 基本也就掌握了。

import os
from datetime import datetime
from transformers import GPT2LMHeadModel
from transformers import BertTokenizerFast
import torch.nn.functional as F
from parameter_config import *PAD = '[PAD]'
pad_id = 0def top_k_top_p_filtering(logits, top_k=0, filter_value=-float('Inf')):"""不需要掌握,了解即可使用top-k和/或nucleus(top-p)筛选来过滤logits的分布,这里只演示top-k参数:logits: logits的分布,形状为(词汇大小)top_k > 0: 保留概率最高的top k个标记(top-k筛选)。)。"""assert logits.dim() == 1  # batch size 1 for now - could be updated for more but the code would be less cleartop_k = min(top_k, logits.size(-1))  #确保top_k不超过logits的最后一个维度大小,即top_k不超过词汇长度if top_k > 0:# 移除概率小于top-k中的最后一个标记的所有标记indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]# torch.topk()返回最后一维中最大的top_k个元素,返回值为(values, indices),values为最大的k个元素,indices为最大的k个元素对应的下标# torch.topk(logits, top_k)[0][..., -1, None]相当于values[..., -1, None]# values[..., -1]表示选择张量 values 的最后一个元素,即k个选中的元素中最小的,因为values是从大到小排列# a[..., -1, None]则表示在最后的结果上增加一个维度,这样就能在和logits比较时进行广播操作# 最后得到的 indices_to_remove 是一串布尔索引,用于标记需要被过滤掉的id# 对于topk之外的其他元素的logits值设为负无穷logits[indices_to_remove] = filter_valuereturn logitsdef main():pconf = ParameterConfig()device = 'cuda' if torch.cuda.is_available() else 'cpu'print('using device:{}'.format(device))# 创建分词器,这里创建分词器的参数,要和data_preprocess/preprocess.py中的分词器参数完全一致tokenizer = BertTokenizerFast(vocab_file=pconf.vocab_path,sep_token="[SEP]",pad_token="[PAD]",cls_token="[CLS]")# 创建模型对象,并转移到指定设备,调整为评估状态model = GPT2LMHeadModel.from_pretrained('./save_model/min_ppl_model_bj')model = model.to(device)model.eval()history = []print('开始和我的助手小医聊天:')while True:try:# 获取输入text = input("user:")# 对输入的句子分词,获得各个字符对应的idtext_ids = tokenizer.encode(text, add_special_tokens=False)# 将输入句子中各个字符的id加入到对话历史中history.append(text_ids)# 构建输入的id列表input_ids = [tokenizer.cls_token_id]  # 每个input以[CLS]为开头# 构建喂给模型的完整输入# 因为输入模型的,不单单只有当前的提问信息,还需要有最近的一部分历史对话信息# 所以这里是从history获取最近的一部分对话信息的id# 这里要确保喂给模型的数据格式为:[CLS]seq1[SEP]seq2[SEP]seq3[SEP]# pconf.max_history_len是每次提问时,需要考虑的历史对话长度(问+答的合计数量,不是“问答对”的数量)for history_id, history_utr in enumerate(history[-pconf.max_history_len:]):input_ids.extend(history_utr)               # 将历史对话中的句子(id列表)加入到 input_ids 中input_ids.append(tokenizer.sep_token_id)    # 添加分隔符# 将喂给模型的输入转换成张量input_ids = torch.tensor(input_ids).long().to(device)# 添加batch_size维度input_ids = input_ids.unsqueeze(0)  # 这条执行之后,input_ids的形状为(1, seq_len),其中seq_len为当前input_ids的长度# 根据context,生成的responseresponse = []# 最多生成max_len个token,# 模型输入的时候是一串字符的id,但输出确是一个概率列表,即词表中每个字符的概率,# 也就是说,模型每次前向传播,只生成一个字符for _ in range(pconf.max_len):# 获取模型的输出outputs = model(input_ids=input_ids)logits = outputs.logits# logits的形状为(1, seq_len, vocab_size),seq_len为当前input_ids的长度,vocab_size为词表大小# 你给模型每输入一个字符,都会输出下一个字符的概率列表,这里输入了seq_len个,所以是(1, seq_len, vocab_size)# 例如,若logits[0, 5, 3]为0.5,则表示在第5个字符输入之后,模型认为,下一个字符为词表中id为3的token的概率为0.5# 当然,这里说logits是概率并不准确,它需要先取top-k或者top-p,然后再softmax# 最后一个字符输入之后,生成的下一个字符的概率值,即词表中各个字符的概率next_token_logits = logits[0, -1, :]# 循环的第一轮,response为空# 从循环第二轮开始,需要对生成字符的概率分布进行惩罚,因为如果模型生成了重复的字符,那么这个重复的字符的概率应该降低for id in set(response):# 这里的id是模型在之前推理的时候已经出现过的字符的id,# 为了避免接下来要生成的字符与前面已经生成过的字符一样,所以这里需要对这些字符的概率进行惩罚next_token_logits[id] /= pconf.repetition_penalty# 对于[UNK]的概率设为无穷小,也就是说模型生成的下一个词不可能是[UNK]这个tokenunk_id = tokenizer.convert_tokens_to_ids('[UNK]')   # 获取[UNK]的idnext_token_logits[unk_id] = -float('Inf')# 使用top-k和/或nucleus(top-p)筛选来过滤logits的分布,这里使用top-kfiltered_logits = top_k_top_p_filtering(next_token_logits, top_k=pconf.topk)# 这里的filtered_logits,只有top-k对应位置有概率值,其他字符对应的概率值都为无穷小# 从k个候选字符中,随机抽取一个字符作为下一个字符,这里获得的是对应字符的idnext_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)# 通过softmax可以对概率值进行归一化# torch.multinomial表示从候选集合中无放回地进行抽取num_samples个元素,权重越高,抽到的几率越高# 遇到[SEP]则表明response生成结束if next_token == tokenizer.sep_token_id:break# 如果没有结束,则把选中的字符对应的 id 加入到 response 列表中response.append(next_token.item())# 将选中的字符对应的id加入到输入中,作为下一轮的输入input_ids = torch.cat((input_ids, next_token.unsqueeze(0)), dim=1)# 将response加入到对话历史# 上述循环执行完毕后,response就是模型生成的回答,它是一个id列表history.append(response)# 将response转换为token,即把id列表转换成实际字符,然后打印出来text = tokenizer.convert_ids_to_tokens(response)print("chatbot:" + "".join(text))except KeyboardInterrupt:breakif __name__ == '__main__':main()

因为目前在目录save_model/min_ppl_model_bj下还没有模型,所以上述代码暂时跑不通。之所以现在就讲推理,主要是因为需要先知道模型的输出是什么,否则后面的损失函数不好介绍。
这里pconf.max_history_lenpconf.repetition_penaltypconf.max_len参数,在前面参数配置时可能不明白,但看懂了这个脚本中的代码,基本就懂了。

这里需要重点掌握模型的输出,即outputs.logits的含义,因为涉及到后面的损失函数与目标值之间的配对。

5 模型训练

5.1 损失函数

在当前目录下新建一个名为function_tools.py的文件,在这个文件下,我们要加入两个函数,分别是损失函数和精度计算函数。
其中,损失函数的代码如下,从代码上可以看到,如果不做标签平滑,那计算损失函数的步骤非常简单,如果要做标签平滑,那么计算过程稍微复杂了一些,我花了很长时间做注释,尽量降低阅读的难度。

#-*- coding: utf-8 -*-
import torch
import torch.nn.functional as Fdef caculate_loss(logit, target, pad_idx, smoothing=False):'''计算模型的损失:通过函数解析下,GPT2内部如何计算损失的:param logit: 模型预测结果,形状为 (batch_size, seq_len, vocab_size),seq_len为input_ids的长度,vocab_size为词表大小:param target: 真实标签,形状为 (batch_size, seq_len):param pad_idx: 需要忽略的索引:param smoothing: 是否进行标签平滑处理:return:'''if smoothing:# 预测值与标签进行数据对齐,序列的第一个字符输入模型,模型的输出理应是序列的第二个字符# 因此第一个字符输入后,模型输出应该与第二个字符进行比较,因此target从第二个字符开始,如果对RNN或者LSTM比较熟悉,很容易理解# 若某条句子输入后,logit的形状为(1, seq_len, vocab_size),那么logit最后一个字符没有对应的标签,# 因此计算损失函数时,logit只需要取出倒数第二个字符进行比较logit = logit[..., :-1, :].contiguous().view(-1, logit.size(2)) # 三维变形成二维,方便计算损失target = target[..., 1:].contiguous().view(-1)eps = 0.1		# 标签平滑系数n_class = logit.size(-1)    # 词表中的字符数# 将标签值进行one-hot编码one_hot = torch.zeros_like(logit).scatter(1, target.view(-1, 1), 1)# 因为前面已经把logit从三维变成了二维,所以这里logit的维度为(batch_size * seq_len, vocab_size)# target是目标字符的id,即在词表中的索引# 这里是沿着第1维进行scatter操作,即对每行,以target为列索引,然后把1给映射过去,实际上就是在目标字符的位置上标1# 平滑处理,让原本为 1 的位置减少一些,而原本为 0 的位置增加一些# 这样做的目的是让模型在训练时不过分依赖于单一的正确类别,而是考虑其他类别的可能性,从而提高泛化能力。one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)# 预测值进行log softmax处理,方便计算交叉熵损失log_prb = F.log_softmax(logit, dim=1)# 非填充标记,以过滤掉 pad_idx 对损失的贡献(稍后会看到)non_pad_mask = target.ne(pad_idx)# ne 是 "not equal"(不等于)的缩写,用于比较两个张量或标量是否不相等,# 它返回一个布尔张量,其中每个元素表示对应位置的值是否满足“不等于”的条件# 计算交叉熵损失loss = -(one_hot * log_prb).sum(dim=1)# 选择那些标签不是 pad_idx 的损失,并求平均值loss = loss.masked_select(non_pad_mask).mean()  # average laterelse:logit = logit[..., :-1, :].contiguous().view(-1, logit.size(-1))labels = target[..., 1:].contiguous().view(-1)loss = F.cross_entropy(logit, labels, ignore_index=pad_idx)return loss

损失函数的计算过程最好也要掌握,在自然语言处理领域,都是使用交叉熵来计算损失函数,上面的代码掌握了,基本上NLP领域的交叉熵都掌握了。

5.2 精度计算

上面的损失函数代码看懂之后,精度计算的代码也就很容易了。

def calculate_acc(logit, labels, ignore_index=-100):"""计算准确率,忽略特定索引的预测结果。参数:logit (Tensor): 模型的预测输出,形状为 (batch_size, seq_len, vocab_size)。labels (Tensor): 实际标签,形状为 (batch_size, seq_len)。ignore_index (int): 需要忽略的索引,默认为 -100。返回:n_correct (int): 预测正确的单词数量。n_word (int): 不包括忽略索引的总单词数量。"""# 调整预测输出和标签的形状,以便进行比较logit = logit[:, :-1, :].contiguous().view(-1, logit.size(-1))labels = labels[:, 1:].contiguous().view(-1)# 对每个预测字符,取出最大概率值以及对应索引_, logit = logit.max(dim=-1)  # 对于每条数据,返回最大的index# 创建一个掩码,忽略特定索引的预测结果,即以过滤掉 ignore_index 对精度的影响non_pad_mask = labels.ne(ignore_index)# 计算预测正确的字符数量n_correct = logit.eq(labels).masked_select(non_pad_mask).sum().item()# logit.eq(labels)返回布尔索引,该索引将 logit 等于 labels 的位置标记为 True,不相等的位置标记为 False# 计算标签中的总字符数,不包括填充字符(即ignore_index对应的字符)n_word = non_pad_mask.sum().item()return n_correct, n_word

5.3 模型训练

终于到了模型训练了,在当前目录下新建一个名为train.py的脚本,先假如如下代码:

#-*- coding: utf-8 -*-
import os
import torch
from datetime import datetime
import transformers
from transformers import GPT2LMHeadModel, GPT2Config    # 配置定义GPT2模型
from transformers import BertTokenizerFast              # 使用BERT的分词器import sys
sys.path.append('data_preprocess/')
from functions_tools import *           # 导入自定义的工具类函数(计算损失和准确率)
from parameter_config import *          # 导入项目的配置文件(训练数据集路径和训练的轮次参数等)
from data_preprocess.dataloader import *                # 导入数据:dataloaderdef main():# 初始化配置参数params = ParameterConfig()# 初始化tokenizertokenizer = BertTokenizerFast(params.vocab_path,sep_token="[SEP]",pad_token="[PAD]",cls_token="[CLS]")# 创建模型的保存目录# 如果没有创建会自动的创建输出目录if not os.path.exists(params.save_model_path):os.mkdir(params.save_model_path)# 创建模型if params.pretrained_model:  # 加载预训练模型(如果有)model = GPT2LMHeadModel.from_pretrained(params.pretrained_model)else:  # 初始化模型model_config = GPT2Config.from_json_file(params.config_json)model = GPT2LMHeadModel(config=model_config)# 移到指定设备model = model.to(params.device)# 确认模型配置的词表长度,和分词器所用的词表,长度是否一致assert model.config.vocab_size == tokenizer.vocab_size# 计算模型参数数量num_parameters = 0parameters = model.parameters()for parameter in parameters:num_parameters += parameter.numel()print(f'模型参数总量---》{num_parameters}')# 加载训练集和验证集train_dataloader, validate_dataloader = get_dataloader(params.train_path, params.valid_path, params.batch_size)# 训练train(model, train_dataloader, validate_dataloader, params)if __name__ == '__main__':main()

这里有个train()函数,内容如下:

def train(model,  train_dataloader, validate_dataloader, args):# 计算整体训练完,需要迭代的步数t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.epochs# 创建优化器,transformer系列的模型,都使用AdamW优化器# eps,在计算梯度的时候,为了增加数值计算的稳定性而加到分母里的项,其为了防止在实现中除以零optimizer = transformers.AdamW(model.parameters(), lr=args.lr, eps=args.eps)# 学习率调度器,用于学习率预热,线性增加学习率scheduler = transformers.get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)print('starting training')# 用于记录每个epoch训练和验证的losstrain_losses, validate_losses = [], []# 记录验证集的最小loss(遇见比10000更小的,就替换)best_val_loss = 10000# 开始训练for epoch in range(args.epochs):# 训练train_loss = train_epoch(model=model, train_dataloader=train_dataloader,optimizer=optimizer, scheduler=scheduler,epoch=epoch, args=args)train_losses.append(train_loss)# 验证validate_loss = validate_epoch(model=model, validate_dataloader=validate_dataloader,epoch=epoch, args=args)validate_losses.append(validate_loss)# 保存当前困惑度最低的模型,困惑度低,模型的生成效果不一定会越好# 验证集损失越小,证明生成的句子越接近标签,从而句子越通顺,模型困惑度越低if validate_loss < best_val_loss:best_val_loss = validate_lossprint('saving current best model for epoch {}'.format(epoch + 1))model_path = os.path.join(args.save_model_path, 'min_ppl_model_bj'.format(epoch + 1))if not os.path.exists(model_path):os.mkdir(model_path)model.save_pretrained(model_path)

这里面每个epoch的训练和验证都封装成了函数,其中训练如下:

def train_epoch(model,train_dataloader,optimizer, scheduler,epoch, args):''':param model: :param train_dataloader: :param optimizer::param scheduler: 调度器:param epoch: 当前的轮次:param args: 模型配置文件的参数对象:return:'''model.train()device = args.deviceignore_index = args.ignore_index# 记录该epoch训练开始时间(每个epoch训练时,这个变量都会更新,如果不保存就用不到)epoch_start_time = datetime.now()# 设定一个变量保存整个epoch的loss总和total_loss = 0# epoch_correct_num: 每个epoch中,output预测正确的字符的数量# epoch_total_num: 每个epoch中,output预测的字符的总数量epoch_correct_num, epoch_total_num = 0, 0for batch_idx, (input_ids, labels) in enumerate(train_dataloader):input_ids = input_ids.to(device)labels = labels.to(device)# 模型前向传播# 如果对模型输入不仅包含input还包含标签,那么得到结果直接就有loss值# 如果对模型的输入只有input,那么模型的结果不会含有loss值,此时,可以自定义函数来计算损失outputs = model.forward(input_ids, labels=labels)# 获得损失loss = outputs.lossloss = loss.mean() # 统计该batch的预测token的正确数与总数logits = outputs.logitsbatch_correct_num, batch_total_num = calculate_acc(logits, labels, ignore_index=ignore_index)# 计算该batch的accuracybatch_acc = batch_correct_num / batch_total_num# 统计该epoch的预测token的正确数与总数epoch_correct_num += batch_correct_numepoch_total_num += batch_total_num# 计算该epoch的总损失total_loss += loss.item()# args.gradient_accumulation_steps 是需要累积的步数,这里是进行一定step的梯度累计之后,再更新参数# batch_size越大越好,受异常值影响小,模型就稳定,但受硬件的限制,batchsize没办法很大,因此这里使用梯度累积# 即积累多个batch_size的梯度后,才更新一次参数,变相增大了batch_size# 而每次计算得到的loss,都是除了batch_size的,即它是单样本的,# 如果这里不除以gradient_accumulation_steps,那么loss将是包含多个样本的损失if args.gradient_accumulation_steps > 1:loss = loss / args.gradient_accumulation_steps# 反向传播loss.backward()# 梯度裁剪,避免发生梯度爆炸torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)# 累积到一定step后,更新参数if (batch_idx + 1) % args.gradient_accumulation_steps == 0:# 更新参数optimizer.step()# 更新学习率scheduler.step()# 清空梯度信息optimizer.zero_grad()# 累积到一定step后,打印信息if (batch_idx + 1) % args.loss_step == 0:print("batch {} of epoch {}, loss {}, batch_acc {}, lr {}".format(batch_idx + 1, epoch + 1, loss.item() * args.gradient_accumulation_steps, batch_acc, scheduler.get_lr()))del input_ids, outputs# 记录当前epoch的平均loss与accuracyepoch_mean_loss = total_loss / len(train_dataloader)epoch_mean_acc = epoch_correct_num / epoch_total_numprint("epoch {}: loss {}, predict_acc {}".format(epoch + 1, epoch_mean_loss, epoch_mean_acc))# 保存模型if epoch % 10 == 0 or epoch == args.epochs:print('saving model for epoch {}'.format(epoch + 1))model_path = os.path.join(args.save_model_path, 'bj_epoch{}'.format(epoch + 1))if not os.path.exists(model_path):os.mkdir(model_path)# 保存模型model.save_pretrained(model_path)print('epoch {} finished'.format(epoch + 1))# 获取该epoch训练结束时间epoch_finish_time = datetime.now()# 打印用时print('time for one epoch: {}'.format(epoch_finish_time - epoch_start_time))return epoch_mean_loss

细心的同学可能会发现,这里我们并没有用到我们自己写的损失函数,而前面之所以介绍损失函数,是为了理解大模型中损失函数的计算过程。

此外,这里还保存了一次模型,前面讲tran()函数的时候,也保存了一次模型。train函数保存的是困惑度最低的模型,train_epoch则是在epoch能被10整除时保存模型。

训练函数的代码看懂了,验证函数就很容易了:

def validate_epoch(model, validate_dataloader, epoch, args):model.eval()device = args.deviceignore_index = args.ignore_indexprint("start validating")# 记录验证开始时间epoch_start_time = datetime.now()# 定义一个变量用于收集验证集上的损失total_loss = 0# 使用with结构,可以捕获cuda out of memory exceptionwith torch.no_grad():for batch_idx, (input_ids, labels) in enumerate(validate_dataloader):input_ids = input_ids.to(device)labels = labels.to(device)outputs = model.forward(input_ids, labels=labels)logits = outputs.logitsloss = outputs.lossloss = loss.mean()total_loss += loss.item()del input_ids, outputs      # 释放显存# 记录当前epoch的平均loss,并打印epoch_mean_loss = total_loss / len(validate_dataloader)print("validate epoch {}: loss {}".format(epoch+1, epoch_mean_loss))# 记录验证结束时间,并打印用时epoch_finish_time = datetime.now()print('time for validating one epoch: {}'.format(epoch_finish_time - epoch_start_time))return epoch_mean_loss

至此,代码介绍完毕。train.py这个脚本,因为要执行100个epoch,执行完可能需要一天的时间。

5.4 最终的项目结构

在这里插入图片描述
这里面save_model目录下应该有很多bj_epochX,其中X为数字,因为每10个epoch保存一次,但我这里为了方便显示,删掉了9个。

5.5 训练后的推理

现在我们已经训练好了模型,save_model/min_ppl_model_bj下已经有模型权重了,可以进行推理了,运行inference.py:

using device:cuda
开始和我的助手小医聊天:
user:头疼了怎么办?
chatbot:最近不知道怎么回事睡觉,特别影响板式效率、白天也很正常。晚上一直疼痛的厉害就是感到了凌晨头和手背都会有刺激性东西堵着地下来后又已经开始干呕但这几次我还没去医院做过仔细检查才发现自己得病因而言之前只能看见好些[CLS]较远总要办?非子在10岁记忆力想比成径法里面系时候或老公每年轻微腿跟外臂却可以吃太久动型水肿呢阻出今期间隔切乎右眼与主肚腹腔往点内注意义并说明显示患者为情况所谓另问
user:感冒了怎么办?
chatbot:根据你的病史推测,考虑是由于感冒了。可以口服头孢克洛颗粒治疗一下就好点儿!建议去医院检查几项看如何为阴性?[CLS]传染吗?临床实施您先做过敏原理解决问题药我认晓这种怎么办呀.必要到大概需仔细观察排除其他得不想经开始干呕但那定时没有送进行咨询见效果之前只能够说话来发觉着呢因而言另外再次出现自己患者会诊断和
user:

可以看到,我们训练的模型,还不够只能,一个原因是100个epoch并不能使训练充分,另一个原因是语料库比较小。


http://www.ppmy.cn/news/1575278.html

相关文章

【蓝桥杯】每天一题,理解逻辑(1/90)【Leetcode 移动零】

文章目录 题目解析讲解算法原理【双指针算法思路】(数组下标充当指针)如何划分和执行过程大致 代码详情 题目解析 题目链接&#xff1a;https://leetcode.cn/problems/move-zeroes/description/ 题目意思解析 把所有的零移动到数组的末尾保持非零元素的相对顺序 理解了这两层…

Oracle中补全时间的处理

在实际数据处理的过程中&#xff0c;存在日期不连续的问题&#xff0c;可能会导致数据传到前后端出现异常&#xff0c;为了避免这种问题&#xff0c;通常会从数据端进行日期不全的处理&#xff1a; 以下为补全年份的案例&#xff1a; with x as (select 开始年份 &#xff08;…

数字可调控开关电源设计(论文+源码)

1 设计要求 在本次数字可调控开关电源设计过程中&#xff0c;对关键参数设定如下&#xff1a; &#xff08;1&#xff09;输入电压&#xff1a;DC24-26V,输出电压&#xff1a;12-24&#xff08;可调&#xff09;&#xff1b; &#xff08;2&#xff09;输出电压误差&#xf…

开源机器学习框架

TensorFlow 是由谷歌开发的一个开源机器学习框架&#xff0c;用于构建和训练深度学习模型。它的核心概念是张量&#xff08;Tensor&#xff09;&#xff0c;即多维数组&#xff0c;用于表示数据。TensorFlow 中的计算以数据流图的形式表示&#xff0c;图中的节点表示各种数学操…

【蓝桥杯集训·每日一题2025】 AcWing 5437. 拐杖糖盛宴 python

5437. 拐杖糖盛宴 Week 2 2月25日 题目描述 农夫约翰的奶牛们非常爱吃甜食&#xff0c;尤其爱吃拐杖糖。 约翰一共有 N N N 头奶牛&#xff0c;编号 1 ∼ N 1 \sim N 1∼N&#xff0c;其中第 i i i 头奶牛的初始高度为 a i a_i ai​。 约翰给奶牛们准备了 M M M 根拐杖…

Linux | GRUB / bootloader 详解

注&#xff1a;本文为 “Linux | GRUB / bootloader” 相关文章合辑。 英文引文&#xff0c;机翻未校。 图片清晰度限于引文原状。 未整理去重。 What is Grub in Linux? What is it Used for? Linux 中的 Grub 是什么&#xff1f;它的用途是什么&#xff1f; Abhishek …

LangChain:Models、Prompts、Indexes、Memory、Chains、Agents。MaxKB

LangChain:Models、Prompts、Indexes、Memory、Chains、Agents 在LangChain框架中,Models、Prompts、Indexes、Memory、Chains、Agents是六大核心抽象概念,它们各自承担独特功能,相互协作以助力开发者基于大语言模型构建高效智能应用。 Models(模型):指代各类大语言模型…

1.1部署es:9200

安装es&#xff1a;root用户&#xff1a; 1.布署java环境 - 所有节点 wget https://d6.injdk.cn/oraclejdk/8/jdk-8u341-linux-x64.rpm yum localinstall jdk-8u341-linux-x64.rpm -y java -version 2.下载安装elasticsearch - 所有节点 wget ftp://10.3.148.254/Note/Elk/…