GPT3 SFT微调中文1.3B参数量文本生成模型

news/2025/1/13 10:52:53/

本模型在中文GPT-3 1.3B预训练模型的基础上,通过有监督的sft数据训练得到,具备更强的通用生成能力,对话能力等。目前模型可以支持单轮对话,多轮对话,知识增强等不同输入模式。

模型描述

GPT-3模型使用Transformer的Decoder结构,并对Transformer Decoder进行了一些改动,原本的Decoder包含了两个 Multi-Head Attention 结构,GPT-3只保留了 Mask Multi-Head Attention,利用常规的语言建模优化,从左到右的自回归预训练。本模型是基于GPT-3的代码结合大量中文无监督数据和下游任务数据预训练得到,我们训练了多种不同参数的模型 ,GPT-3模型介绍,详见:Language Models are Few-Shot Learners

本项目我们复现了一系列不同规模的中文GPT3模型,包括base/large/1.3B/2.7B/13B/30B/175B等,以及推出基于不同规模模型训练得到的sft微调版本模型,本模型是其中1.3B的版本。全部版本如下表所示:

ModelLayersHeadsd_modelLRBatch
base12127686.0e-40.5M
large241610243.0e-40.5M
1.3B243220482.0e-42M
2.7B323225601.6e-42M
13B404051201.0e-46M
30B485671681.0e-46M
175B(work in process)9696122881.2e-46M

期望模型使用方式以及适用范围

本模型可直接用于文本生成,也可以通过finetune用于各类文本理解的任务。用户可以自行尝试各种输入文档。具体调用方式请参考代码示例。

如何使用

在安装完成ModelScope library之后即可使用GPT-3的text-generation的能力。目前我们免费提供试用的Notebook环境,使用的是单卡GPU,由于显存限制仅可以运行pipeline推理,如果用户在使用Notebook环境时想要运行训练,推荐使用更小规模的large/base版本

代码范例

 
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasksif __name__ == '__main__':model_id = 'damo/nlp_gpt3_sft_text-generation_1.3B'pipe = pipeline(Tasks.text_generation, model=model_id, model_revision='v1.0.2')# 可以在 pipe 中输入 max_length, top_k, top_p, temperature 等生成参数# 不包含对话历史输入input = '当前输入:以AIGC技术未来有广阔的应用场景为题写一篇文章\n当前输出:'print(pipe(input, max_length=1024, top_k=10, temperature=0.7, top_p=0.0))# 包含对话历史输入,支持3轮历史input = '输入:你叫什么名字?\n输出:我的名字是AliceMind</s>\n输入:你能做什么\n输出:我能写作文,写小说等一系列生成能力</s>\n当前输入:以AIGC技术未来有广阔的应用场景为题写一篇文章\n当前输出:'print(pipe(input, max_length=1024, top_k=10, temperature=0.7, top_p=0.0))# 支持增加检索知识提升性能know_list = ['《狂飙》由徐纪周执导的。《狂飙》的导演徐纪周也是编剧之一,代表作品有《永不磨灭的番号》《特战荣耀》《心理罪之城市之光》《杀虎口》《胭脂》等','《狂飙》(The Knockout)是一部由 张译、张颂文、李一桐、张志坚 领衔主演,韩童生 特邀主演,吴健、郝平 友情出演,高叶、贾冰、李健 主演,徐纪周 执导,朱俊懿、徐纪周 担任总编剧的 刑侦',' 狂飙是由中央政法委宣传教育局,中央政法委政法信息中心指导,爱奇艺,留白影视出品,徐纪周执导,张译,李一桐,张志坚领衔主演的刑侦剧。不是。是徐纪周,1976年12月19日出生,毕业于中央戏剧']input = '检索知识:' + '\n'.join(know_list) + '\n当前输入:狂飙的导演是谁\n当前输出:'print(pipe(input, max_length=1024, top_k=10, temperature=0.7, top_p=0.0))

模型局限性以及可能的偏差

模型训练数据来源于网络,生成结果可能存在一定偏差。

训练数据介绍

训练数据包括中文维基百科、网络上公开文本数据。

模型训练流程

本模型的预训练分为两个阶段。第一阶段严格按照原始GPT3的参数设置进行:在中文wiki/ Common crawl等无监督数据上,通过自回归的训练任务训练了约300B字得到。第二阶段中,我们加入了多种有监督数据继续训练,使得模型具备多种任务的zero-shot的处理能力。

我们为GPT3模型支持了续写训练与输入输出形式的训练,训练方式不需要额外指定,训练数据集仅包含 src_txt 时会进行续写训练,同时包含 src_txt 和 tgt_txt 时会进行输入输出形式的训练。以下将为两种训练方式提供示例代码。

训练准备(重要)

目前对于GPT3 1.3B/2.7B 两个模型我们在训练阶段支持了运行时的模型自动拆分功能,因此不需要使用与训练tensor并行度匹配的checkpoint即可开始训练。需要注意的是,当使用并行训练时,务必提前在configuration.json中确认以下参数配置正确:

"megatron": {"checkpoint_tensor_model_parallel_size": 1, # 对应checkpoint的并行片数,在1.3B/2.7B模型中为1"world_size": 1, # 全局的并行进程数"tensor_model_parallel_size": 1 # tensor 并行度
}

以单机8卡训练,2的数据并行度和4的tensor并行度(2dp+4tp)为例:

# 训练启动命令
torchrun --nproc_per_node 8 finetune_dureader.py # 这里的8是启动进程数,为dp*tp的值(2*4=8),单机训练时对应配置文件中的world_size
# 无需配置数据并行度,会根据`world_size/tensor_model_parallel_size`计算
# 此时的正确配置
"megatron": {"checkpoint_tensor_model_parallel_size": 1, # 使用modelscope上传的checkpoint时无需修改"world_size": 8, # 此处对应上文的启动进程数nproc_per_node,如果使用其他方式启动多进程训练同理"tensor_model_parallel_size": 2 # 训练使用的 tensor 并行度
}

输入输出形式训练

下面是基于GPT-3中文1.3B模型在Dureader问题生成数据集上二次开发训练

# finetune_dureader.py
from torch.utils.tensorboard import SummaryWriter
from modelscope.msdatasets import MsDataset
from modelscope.trainers import build_trainer
from modelscope.metainfo import Trainersdataset_dict = MsDataset.load('DuReader_robust-QG')train_dataset = dataset_dict['train'].remap_columns({'text1': 'src_txt', 'text2': 'tgt_txt'}) \.map(lambda example: {'src_txt': example['src_txt'] + '\n'})
eval_dataset = dataset_dict['validation'].remap_columns({'text1': 'src_txt', 'text2': 'tgt_txt'}) \.map(lambda example: {'src_txt': example['src_txt'] + '\n'})max_epochs = 10tmp_dir = './gpt3_dureader'num_warmup_steps = 200def noam_lambda(current_step: int):current_step += 1return min(current_step**(-0.5),current_step * num_warmup_steps**(-1.5))def cfg_modify_fn(cfg):cfg.train.lr_scheduler = {'type': 'LambdaLR','lr_lambda': noam_lambda,'options': {'by_epoch': False}}cfg.train.optimizer = {'type': 'AdamW', 'lr': 1e-4}cfg.train.dataloader = {'batch_size_per_gpu': 4,'workers_per_gpu': 1}cfg.train.hooks.append({'type': 'MegatronHook'})cfg.preprocessor.sequence_length = 1024cfg.model.checkpoint_model_parallel_size = 1return cfgkwargs = dict(model='damo/nlp_gpt3_sft_text-generation_1.3B',train_dataset=train_dataset,eval_dataset=eval_dataset,max_epochs=max_epochs,work_dir=tmp_dir,cfg_modify_fn=cfg_modify_fn)trainer = build_trainer(name=Trainers.gpt3_trainer, default_args=kwargs)
trainer.train()

以上为数据并行度为1的训练脚本,我们推荐使用 torchrun 拉起训练

单机单卡或单机多卡运行时,可以通过以下命令运行训练:

# N 为模型并行度
torchrun --nproc_per_node $N finetune_dureader.py

需要注意,目前1.3B参数量的GPT-3模型训练至少需要32G显存的gpu(如V100)才能进行单卡训练,或至少需要两张16G显存的gpu进行张量并行训练

推理加速

我们对大规模生成模型的推理速度进行了极致优化,13B模型128字的文本生成可以在1秒左右完成。


http://www.ppmy.cn/news/378431.html

相关文章

C# WinForm 自动化必备类库

自动化一定离不开异步执行&#xff0c;异步的概念相对于同步&#xff0c;异步的实现方式是多线程。多线程就是通过CPU分配时间段来执行每个线程。 1.BeginInvoke 应用于winform上修改UI /* Control.BeginInvoke 方法定义 命名空间:System.Windows.Forms 程序集:System.Win…

PB串口王

PB串口王 Function ulong OpenPort(ulong HWND,uint uMsg,ulong dwPort,ulong dwBaud,char byByteSize, char byParity, char byStopBits) Library "comIO.dll" subroutine ClosePort(ulong dwHandle) Library "comIO.dll" subroutine StopIO() Libr…

RS/6000液晶显示屏上显示代码(LED)的含义

RS/6000液晶显示屏上显示代码(LED)的含义 本文介绍RS/6000启动过程中机器上的液晶显示屏代码的含义。本文代码不针对具体机型。 ---------- Dump Progress Indicator ---------- 0c0 The dump completed successfully0c1 The dump failed due to an I/O error.0c2 A user-r…

AIX常见问题整理

AIX常见问题整理 创建时间&#xff1a;2002-08-17 文章属性&#xff1a;原创 文章来源&#xff1a; www.cnsafe.net 文章提交&#xff1a; mayi (mayi99_at_263.net) by:ciline 来自&#xff1a; www.cnsafe.net 提纲&#xff1a; 用feprom_update升级Firmware 2002-0…

python浅学笔记9-IO编程

学习文档 from https://www.liaoxuefeng.com IO编程文件读写读文件file-like object二进制文件字符编码写文件 StringIO和BytesIOStringIOBytesIO 操作文件和目录操作系统环境变量文件和目录目录文件 序列化 picklingJSON IO编程 name:Input/Output face:磁盘&#xff0c;网络…

基于Java+Swing+Mysql实现通讯录管理系统

基于JavaSwingMysql实现通讯录管理系统 一、系统介绍二、功能展示1.用户登陆2.查询信息3.新增信息4.修改信息5.删除信息 三、数据库四、其他系统实现五、获取源码 一、系统介绍 1.登录系统 2.查询信息 3.新增信息 4.修改信息 5.删除信息 运行环境&#xff1a;idea/eclipse、m…

[转载]AIX 常见问题整理

怎样在AIX 5.1中建立热后备(hot spare)磁盘&#xff1f; 环境 AIX 5.1 问题 怎样在AIX 5.1中建立热后备(hot spare)磁盘&#xff1f; 解答 在AIX 5.1中可以在操作系统的级别上建立hot spare磁盘。 如需要在某一卷组(VG)中建立hot spare磁盘&#xff0c;必须满足如下条件&…

java代码读写者问题_夯实Java基础系列16:一文读懂Java IO流和常见面试题

本系列文章将整理到我在GitHub上的《Java面试指南》仓库&#xff0c;更多精彩内容请到我的仓库里查看 https://github.com/h2pl/Java-Tutorial 喜欢的话麻烦点下Star哈 文章首发于我的个人博客&#xff1a; www.how2playlife.com 本文参考 并发编程网 – ifeve.com IO概述 在这…