【Instruction Tuning】ChatGLM 微调实战(附源码)

news/2024/11/19 9:33:10/

在之前的文章中,我们已经讲过了 ChatGPT 的三个主要流程:

  1. SFT:通过 Instruction Tuning 来微调一个监督学习模型
  2. Reward Model:通过排序序列来训练一个打分模型
  3. Reinforcement Learning:通过强化学习来进一步优化模型。

何枝:【RLHF】想训练ChatGPT?得先弄明白Reward Model怎么训(附源码)409 赞同 · 37 评论文章正在上传…重新上传取消

前两篇文章主要对 RM 和 RL 两部分进行了讲解和实验,

但无数的经验向我们证明 —— 拥有一个好的 SFT 的模型对后两步的训练至关重要。

由于在 RL 训练过程中会加入与 SFT 模型的相似度(KL-Divergence)惩罚,

这意味着 RL 模型的上限很大程度上取决于 SFT 模型

为此,我们今天来重点讲一讲如何通过 ChatGLM 来微调一个读懂我们指令的模型。

1. GLM Backbone

Paper Link:  https://arxiv.org/pdf/2103.10360.pdf

在讲微调代码之前,我们先来看看 GLM 的基本架构。

我们都知道,目前主流的两种 Backbone:一类是以 BERT 为首的 Encoder 架构(双向注意力),另一种是以 GPT 为首的 Decoder 架构(单向注意力)。

这两种架构各有各的好处,一个更适合做理解,一个更适合做生成

那么如何将这两种模型做合并,集二者优势于一身,是近年来人们一直在尝试的努力(如:T5、BART等)。

不同于 Encoder-Decoder 的堆叠,GLM 通过一种巧妙的 2D Position Embedding,并通过 Attention MASK 来使得模型在训练时 「既能在部分内容上存在双向注意力」「又能在生成任务中保持单向注意力」。

以下是 GLM 示意图:

GLM Position Embedding 示意图

  1. 首先,从原始句子中 Random Sample 出来一些 Span 用于并 [MASK] 掉(该思想源自 BERT),注意:这里是以 Span 维度进行 MASK 的。
  2. 将原句子分为两组,PART A 是原句子,只不过句子中被挑选出来的 Span 用 [MASK] 符号代替;PART B 是挑选出来的 Span 集合
  3. 将挑选出来的 MASK Span 集合(PART B)拼接在原句子(PART A)后面,注意:这里是先对 PART B 做乱序后,再拼接到句子后面(目的是为了训练 Position Embedding)。
  4. 设计 2D Position:这是我认为比较有趣的设定,位置编码分成了两组。一组用于表征「全局位置」,被挑选出的「MASK SPAN」中的所有 token 的位置索引都等于整个 Span 在原句子中的位置(例如:x5, x6 的索引都是 5);而另一组用来专门表征 MASK Span 内部 token 的相对位置编码(例如:x5, x6 的索引这两个 token 在 Mask Span 中的相对位置)。
  5. 通过设置 Attention MASK,使得 PART A 中的内容是双向可见的,且 PART B 中所有 token 也可以看到 Part A 中的内容;而对于 PART B 中的内容保持单向可见。
  6. 通过对 Part B 中的内容做「生成任务」来进行模型迭代。

以上便是我认为 GLM 中最关键的几个点。

2. Finetune GLM

2.1 数据集准备

我们以信息抽取任务为例,将一个信息抽取数据集(DuIE)添加上 Instruction,以此来教会 ChatGLM 根据我们的指令来完成抽取任务。

我们仿照 Alpaca 数据集,将数据结构设为以下形式:

{"instruction": "你现在是一个很厉害的阅读理解器,找到句子中的三元组信息并输出成json给我。","input": "九玄珠是在纵横中文网连载的一部小说,作者是龙马。","target": "```json\n[{\"predicate\": \"连载网站\", \"object_type\": \"网站\", \"subject_type\": \"网络小说\", \"object\": \"纵横中文网\", \"subject\": \"九玄珠\"}, {\"predicate\": \"作者\", \"object_type\": \"人物\", \"subject_type\": \"图书作品\", \"object\": \"龙马\", \"subject\": \"九玄珠\"}]\n```"
}

进一步的,我们将 instruction 和 input 字段合并,得到如下数据:

{"context": "Instruction: 你现在是一个很厉害的阅读理解器,找到句子中的三元组信息并输出成json给我:。\nInput: 九玄珠是在纵横中文网连载的一部小说,作者是龙马。\nAnswer: ", "target": "```json\n[{\"predicate\": \"连载网站\", \"object_type\": \"网站\", \"subject_type\": \"网络小说\", \"object\": \"纵横中文网\", \"subject\": \"九玄珠\"}, {\"predicate\": \"作者\", \"object_type\": \"人物\", \"subject_type\": \"图书作品\", \"object\": \"龙马\", \"subject\": \"九玄珠\"}]\n```"
}

其中,

  • Instruction:存放我们希望模型做的任务的指令
  • Input:存放我们喂给模型的任务数据
  • Target:存放模型的输出标签

2.2 Label 构建

将数据集解析为训练 label 的代码如下:

def convert_example(examples: dict, tokenizer,max_source_seq_len: int,max_target_seq_len: int,):"""将样本数据转换为Ptuning模型接收的输入数据。Args:examples (dict): 训练数据样本, e.g. -> {"text": ['{"context": "年基准利率4.35%。从实际看...", "target": "2017年银行贷款基准利率"}',...]}max_source_seq_len (int): prompt最大长度max_target_seq_len (int): 答案最大长度Returns:dict (str: np.array) -> tokenized_output = {'input_ids': [[1525, 10, ...], [758, 2345, ...]], 'labels': [[822, 10, ...], [125, 58...]]}"""tokenized_output = {'input_ids': [],'labels': []}max_seq_length = max_source_seq_len + max_target_seq_lenfor example in examples['text']:try:example = json.loads(example)context = example["context"]target = example["target"]prompts_ids = tokenizer.encode(text=context,add_special_tokens=False)target_ids = tokenizer.encode(text=target,add_special_tokens=False)                    if len(prompts_ids) >= max_source_seq_len:                                          # source 需要留一个 [gMASK] token 在结尾prompts_ids = prompts_ids[:max_source_seq_len - 1]if len(target_ids) >= max_target_seq_len - 1:                                       # target 需要留一个 <sop> 在开头和一个 <eop> token 在结尾target_ids = target_ids[:max_target_seq_len - 2]input_ids = tokenizer.build_inputs_with_special_tokens(prompts_ids, target_ids)     # source_ids + [gMASK] + <sop> + target_ids + <eop>context_length = input_ids.index(tokenizer.bos_token_id)                            # bos 在 target 的第一位mask_position = context_length - 1                                                  # [gMASK] 在 source 的最后一位labels = [-100] * context_length + input_ids[mask_position + 1:]                    # 从 bos 开始到后面所有的 target 到 eos 都为 labelpad_len = max_seq_length - len(input_ids)input_ids = input_ids + [tokenizer.pad_token_id] * pad_lenlabels = labels + [-100] * pad_lentokenized_output['input_ids'].append(input_ids)tokenized_output['labels'].append(labels)except:print(f'"{example}" -> {traceback.format_exc()}')continuefor k, v in tokenized_output.items():tokenized_output[k] = np.array(v)return tokenized_output

其中,

  • max_source_seq_len 用于设定模型接收的最大输入长度
  • max_target_seq_len 用于设定模型输出的最大长度

2.3 模型训练

ChatGLM 的微调存在 LoRA Finetune 和 P-Tuning 两种微调方式。

P-Tuning V.S. LoRA

这两种方式都可以使得 ChatGLM-6B 的模型能在 32G 的 V100 上进行微调训练。

通过以下两种参数配置即可选择使用 P-Tuning 还是 LoRA:

# LoRA Finetune
python train.py \--train_path data/mixed_train_dataset.jsonl \--dev_path data/mixed_dev_dataset.jsonl \--use_lora True \--lora_rank 8 \--batch_size 1 \--num_train_epochs 2 \--save_freq 1000 \--learning_rate 3e-5 \--logging_steps 100 \--max_source_seq_len 400 \--max_target_seq_len 300 \--save_dir checkpoints/finetune \--img_log_dir "log/fintune_log" \--img_log_name "ChatGLM Fine-Tune" \--device cuda:0# P-Tuning
python train.py \--train_path data/mixed_train_dataset.jsonl \--dev_path data/mixed_dev_dataset.jsonl \--use_ptuning True \--pre_seq_len 128 \--batch_size 1 \--num_train_epochs 2 \--save_freq 200 \--learning_rate 2e-4 \--logging_steps 100 \--max_source_seq_len 400 \--max_target_seq_len 300 \--save_dir checkpoints/ptuning \--img_log_dir "log/fintune_log" \--img_log_name "ChatGLM P-Tuning" \--device cuda:0

其中,pre_seq_len 是指在每个层前面添加多少个可学习的前缀 token,该值设置的越大显存占用也会越大。

在我们的实验下,两种方式的效果差异不大:

P-Tuning v.s. LoRA Finetune

模型最终的训练结果如下:

模型训练结果

好啦,以上就是 ChatGLM 的全部内容,感谢观看~

完整源码在这里:

ChatGLM Finetune Code​github.com/HarderThenHarder/transformers_tasks/blo


http://www.ppmy.cn/news/377789.html

相关文章

华硕S400装win7

华硕s400电脑&#xff0c;WIN8系统使用不习惯降级到WIN7。 安装WIN7流程&#xff1a; 开机&#xff0c;按F2&#xff0c;进入bios。 到【Security】&#xff0c;选项【Secure Boot Control】&#xff0c;选择【Disabled】&#xff0c; F10保存并重启 按F2进入bios 到【Boot】…

android 游戏语言设置在哪里设置中文版,使命召唤手游语言变更方法 怎么设置中文...

很多小伙伴们在玩使命召唤手游外服的时候&#xff0c;一打开游戏都是英文的界面&#xff0c;那么这款游戏是怎么设置中文的呢&#xff0c;这里就来和大家分享一下使命召唤手游语言变更设置方法&#xff0c;一起来看看吧。 1、我们进入游戏之后&#xff0c;在主界面的右上方可以…

使命召唤手游服务器维护,使命召唤手游国服登录不了如何解决

使命召唤手游是一款很是刺激的多人射击游戏&#xff0c;玩法十分的吸引人&#xff0c;拥有着超高清的画质&#xff0c;很受喜爱&#xff0c;使命召唤手游国服登录不了怎么解决&#xff1f;现在iefans小编就给大家带来介绍&#xff0c;感兴趣就来看看吧。 一、配置问题 使命召唤…

使命召唤手游服务器显示错误,使命召唤手游无法连接服务器是什么原因

使命召唤手游无法连接服务器是什么原因&#xff0c;相信大家在玩使命召唤手游的过程中&#xff0c;经常会遇到这样的问题&#xff0c;下面ourplay小编就简单为大家介绍几种常见的解决方案。 使命召唤手游游戏简介 《使命召唤手游》是一款由动视和腾讯联合推出的大型多人在线第一…

计算机系统配置有哪些,电脑的配置基本知识 电脑有哪些基本配置

1、CPU&#xff0c;这个主要取决于频率和二级缓存&#xff0c;三级缓存&#xff0c;核心数量。频率越高、二级缓存越大&#xff0c;三级缓存越大&#xff0c;核心越多&#xff0c;运行速度越快。速度越快的CPU只有三级缓存影响响应速度。 2、内存&#xff0c;内存的存取速度取决…

新手玩家一定要学会配枪,使命召唤手游,对枪械是非常专业的

不知道有很多玩家都吃过这样的亏&#xff0c;那就是随便使用一把武器&#xff0c;然后就开始和敌人进行战斗&#xff0c;这样的情况只会让自己更早的被淘汰出去。因此在游戏当中我们是不是可以避免这样的情况呢&#xff1f;其实玩家在游戏当中的体验完全和自己事先的准备有关系…

cf两边黑屏怎么解决win10_使命召唤17黑屏怎么解决 使命召唤17黑屏死机解决方法...

使命召唤17黑屏怎么办?玩家在游戏的时候总是不可避免的出现一些莫名其妙的问题&#xff0c;比如黑屏死机之类的&#xff0c;下面小编给大家介绍使命召唤17黑屏死机解决方法&#xff0c;感兴趣的朋友来了解下哦。 使命召唤17黑屏怎么解决 使命召唤17黑屏死机解决方法 1、首先我…

使命召唤手游ios端终于上线啦:这画质这操作手感我要肝爆它

使命召唤手游如何下载&#xff1f;继上次使命召唤印度国际服测试之后&#xff0c;澳大利亚服体验服也在最近开启了测试&#xff0c;本次测试为安卓、ios双端测试&#xff0c;之前吐槽印度服没有ios的&#xff0c;这次可以去玩澳服了。 1、使命召唤手游体验服如何下载&#xff1…