论文解读 | EMNLP2024 一种用于大语言模型版本更新的学习率路径切换训练范式

server/2024/12/28 19:57:14/

点击蓝字

ae148d5f2c93be30b5254b52294cc51f.jpeg

关注我们

AI TIME欢迎每一位AI爱好者的加入!

ea0f44fcf04d08bb15a80ac3a5818a28.png

点击 阅读原文 观看作者讲解回放!

作者简介

王志豪,厦门大学博士生

刘诗雨,厦门大学硕士生

内容简介

新数据的不断涌现使版本更新成为大型语言模型(LLMs)不可或缺的需求。LLMs的版本更新训练范式包括从头预训练(PTFS)和继续预训练(CPT)。初步实验表明,PTFS在预训练性能上表现更好,而CPT的训练成本较低。此外,随着版本更新的进行,两种范式的性能和训练成本差距逐渐扩大。为探究这一现象的根本原因,作者分析了学习率对CPT的两个阶段的影响:准备初始化参数(checkpoint)和基于该checkpoint的继续预训练。研究表明,在第一阶段中使用较大学习率以及在第二阶段中使用具有完整率衰减过程的学习率对于LLMs的版本更新至关重要。因此,作者提出了一种基于学习率路径切换的训练范式。该范式包括一条主路径和多条分支路径。LLMs在主路径上以最大学习率进行预训练,而分支路径则基于LLMs在主路径上的checkpoints使用新增数据进行版本更新。广泛的实验表明该范式的有效性和泛化性。特别是在训练4个版本的LLMs时,该范式在保持与PTFS相当的预训练性能同时,将总训练成本减至58%。

论文地址:https://arxiv.org/abs/2410.04103

研究动机

这篇论文的研究动机在于,随着新的预训练数据不断涌现,大型语言模型(LLMs)面临着版本更新的需求,以确保模型能力的持续提升。现有的版本更新范式包括从头预训练(Pre-Training From Scratch,简称PTFS)和继续预训练(Continue Pre-Training,简称CPT)。图中列出了这两种范式在使用cosine学习率调度策略时的学习曲线。可以明显看到,两种范式之间的主要区别在于更新过程中的学习率变化。这启发了作者从学习率角度出发,研究新的预训练范式。

46837a1f390baa4e6ac677414ad14776.png

先导实验

为了比较这两种范式在性能和成本上的差异,作者进行了一个先导实验。实验选择了训练LLMs时最常见的3种学习率调度策略,测试在更新4个版本LLMs情况下的性能和成本差距。需要注意的是,这里的性能用困惑度(Perplexity,PPL)来表示,数值越低代表LLMs性能越好。

从实验结果可以观察到,尽管CPT的版本更新成本远低于PTFS,但PTFS的性能优于CPT,而且这种性能差距随着版本数的增加而增大。

27a38635883de3f9743e441f9313af57.png

性能差距增大的原因

为了研究这种性能差距产生的原因,作者将CPT拆分成两个阶段。第一个阶段是为CPT准备初始参数(checkpoint),第二个阶段是基于初始checkpoint进行继续训练。紧接着,作者基于这两个阶段,分别设计两组实验来探索学习率对CPT两个阶段的影响。

在第一组实验中,作者采用具有不同衰减速度的学习率作为第一阶段的学习率曲线,并固定了第二阶段的学习率曲线。结果表明,当第一阶段的学习率固定为最大值时,初始checkpoint的模型性能最低,但最终性能却是最好的。

在第二组实验中,作者固定了第一阶段的学习率曲线,采用具有不同衰减速度的学习率作为第二阶段的学习率曲线。结果显示,当第二阶段学习率快速衰减到最小值时,对应的LLMs性能最佳。

基于上述两组实验我们可以得出如下结论:1.第一阶段的大学习率和第二阶段完整的学习率衰减过程对CPT的性能尤为重要。2. CPT无法兼顾不同版本LLMs的性能。完整的学习率衰减过程能确保当前版本的LLMs的最优性能,但后续版本的LLMs则需要以大学习率训练提供的初始化checkpoint,这是CPT无法同时满足的。

1b68def433351995a81ff4bf1442a9c3.png

训练范式

为了解决CPT两阶段对不同学习率要求的冲突,作者提出了一种基于学习率路径切换的训练范式。该范式包括一条主路径和多条分支路径。在主路径上,LLMs以最大学习率从头开始预训练,为后续版本更新提供初始化checkpoint。当我们想获得新版的LLMs时,可以直接基于主路径的当前checkpoint继续预训练。在这个过程中,学习率会经历一个完整且快速的衰减过程,从而以较低的成本来保证新版LLMs的性能。同时,在主路径上LLMs仍然使用新增数据对当前checkpoint以最大学习率进行预训练,以便于后续的版本更新。

不同于PTFS和CPT,该范式还包含关键超参数α用于控制分支路径在训练步长中所占的比例。根据版本更新的总时间复杂度计算,该范式与CPT一样,确保了线性的复杂度。

2b49d82e98fe17d50d480c4e1c86caf9.png

关键参数实验

对于本文提出的范式,参数α是一个关键参数。α值越高,模型的性能相对越好,但总的训练成本也会相应增加。根据对不同α值的实验结果,作者选择了α等于0.6作为最终参数。

2f8eab31125e29eb53c116e9da4c3930.png

预训练性能

基于三种学习率调度策略下,作者测试了所提出的预训练范式、PTFS及CPT各自的性能和成本。结果显示,所提出的范式在性能和训练成本上取得了更佳的平衡。在实现与PTFS相当性能的同时,仅需要58%的训练成本。

63064ab2a5a841b32e475f46b5997bbc.png

下游任务性能

在9个常见的下游任务中,经过微调训练后,作者的范式取得了最优的平均性能。

daf35d429da5319c9de64bedc768506e.png

泛化性

模型结构和参数规模

为了验证本文范式在模型结构和参数规模上的泛化性,作者不仅在LLaMA结构上进行了实验,还在Qwen模型结构上进行了实验。此外,作者还在不同参数规模下进行了测试。结果证明,该范式具有良好的泛化性能。

3f894c09340b089dd253c45a0131693a.png

数据规模和最大学习

并且,作者还测试了在不同数据规模和不同最大学习率设置下所提出范式的性能。结果表明,该范式在不同数据规模和学习率设置下表现出良好的性能,进一步验证了其适用性和泛化性。

6de189738fe48367e02da9a73f10b95a.png

实际应用与未来工作

最后,目前该范式已经实际应用于vivo蓝心基础大模型的研发中。对于未来的工作,作者计划将视角聚焦于大模型版本更新过程中可能存在的更多待解决场景。例如,在更新过程中,同时伴随模型参数规模的扩展、模型结构的调整,以及版本更新在监督微调(SFT)或对齐(alignment)阶段中的应用。

2ad888782114d8fa25cbf6abd6ce91c7.png

本期文章由陈研整理

往期精彩文章推荐

446094ce859ca62d307e584062d284d6.jpeg

迈向AGI——大模型创新体验嘉年华邀请函

 关于AI TIME 

AI TIME源起于2019年,旨在发扬科学思辨精神,邀请各界人士对人工智能理论、算法和场景应用的本质问题进行探索,加强思想碰撞,链接全球AI学者、行业专家和爱好者,希望以辩论的形式,探讨人工智能和人类未来之间的矛盾,探索人工智能领域的未来。

迄今为止,AI TIME已经邀请了2000多位海内外讲者,举办了逾700场活动,超800万人次观看。

 41d666e1361251598d880c7ab65fa5f3.png

我知道你 

在看

提出观点,表达想法,欢迎 

留言

b2e11938e64126deb46fb83a46b44eb6.gif

点击 阅读原文 观看作者讲解回放!


http://www.ppmy.cn/server/154007.html

相关文章

大数据 深度学习毕设课题帮助

文章目录 🚩 1 前言1.1 选题注意事项1.1.1 难度怎么把控?1.1.2 题目名称怎么取? 1.2 开题选题推荐1.2.1 起因1.2.2 核心- 如何避坑(重中之重)1.2.3 怎么办呢? 🚩2 选题概览🚩 3 项目概览题目1 : 大数据电商…

一个简单的深度学习模型例程,使用Keras(基于TensorFlow)构建一个卷积神经网络(CNN)来分类MNIST手写数字数据集。

下面是一个简单的深度学习模型例程,使用Keras(基于TensorFlow)构建一个卷积神经网络(CNN)来分类MNIST手写数字数据集。例程包括详细的代码和说明。 1. 安装所需库 首先,确保你已经安装了tensorflow&#…

K8S--“ Failed to create pod sandbox: nameserver list is empty“

原因是因为宿主机的/etc/resolv.conf 文件 有残缺, 填写一半,这个问题 cat /etc/resolv.conf填写好后,重启pod或等待一下再查看即可

应对TensorFlow导入Keras时发生的错误问题

在机器学习和深度学习领域,TensorFlow和Keras是两个非常流行的框架。TensorFlow是一个开源的机器学习库,由Google开发,用于设计、构建和训练深度学习模型。而Keras则是一个高层的神经网络API,它能够以TensorFlow等底层框架为基础&…

一篇文章了解 Kafka

文章目录 Kafka 简介什么是 KafkaKafka 的主要特性Kafka 的核心使用场景Kafka 在消息队列领域的地位与优势 Kafka 的架构设计Kafka 的核心组件BrokerProducerConsumerZookeeper/ Kafka Raft (KRaft)Topic 和 Partition 分布式架构设计Leader-Follower 模型分区与副本机制 消息存…

06 - Django 视图view

HttpRequest 和 HttpResponse Django中的视图主要用来接受Web请求,并做出响应。 视图的本质就是一个Python中的函数 视图的响应分为两大类 以Json数据形式返回(JsonResponse)以网页的形式返回 重定向到另一个网页 (HttpResponseRedirect)错误视图(4XX,5XX) (Htt…

MySQL并发问题区别-MVCC如何解决的

脏读 事务a,事务b,b读到了a刚修改未提交的数据 不可重复读 针对同一行记录,两次读到的结果不一致 (范围是一行) 幻读 范围比不可重复读大很多,是表的范围,事务a第一次查的时候不存在&#…

zabbix5.0版本(安装部署+添加服务器+拆分数据库)

目录 1.监控内容 2.监控工具 3.Zabbix安装 4.Zabbix添加监控服务器 5.拆分数据库 本篇文章介绍zabbix监控,监控是对我们操作系统进行不间断的监控,这是软件生命周期非常重要的一环,可以做到事前告警,事后根据监控内容排查问题…