TRL - Transformer Reinforcement Learning(基于Transformer的强化学习)

ops/2025/2/11 9:06:40/

TRL - Transformer Reinforcement Learning(基于Transformer的强化学习

flyfish

地址:https://github.com/huggingface/trl

TRL是一个用于对基础模型进行后训练的全面库,是一个前沿的库,专门用于使用先进的技术(如监督微调(SFT)、近端策略优化(PPO)和直接偏好优化(DPO))对基础模型进行后训练。

利用Accelerate,通过DDP和DeepSpeed等方法,从单个GPU扩展到多节点集群。与PEFT完全集成,通过量化和LoRA/QLoRA,即使在硬件资源有限的情况下也能训练大型模型。集成Unsloth,通过优化的内核加速训练。

通过SFTTrainer、DPOTrainer、RewardTrainer、ORPOTrainer等训练器,轻松访问各种微调方法。使用预定义的模型类(如AutoModelForCausalLMWithValueHead)简化与LLMs的强化学习(RL)

如何使用 TRL(Transformers Reinforcement Learning)库

1. SFTTrainer - 监督微调训练器

  • 用途:用于在自定义数据集上进行监督微调(Supervised Fine-Tuning, SFT)。
  • 使用方法
    1. 加载你的训练数据集,例如 "trl-lib/Capybara"
    2. 配置训练参数,包括输出目录等。
    3. 初始化 SFTTrainer 并传入模型、训练数据集等参数。
    4. 调用 .train() 方法开始训练。
from trl import SFTConfig, SFTTrainer
from datasets import load_datasetdataset = load_dataset("trl-lib/Capybara", split="train")
training_args = SFTConfig(output_dir="Qwen/Qwen2.5-0.5B-SFT")
trainer = SFTTrainer(args=training_args,model="Qwen/Qwen2.5-0.5B",train_dataset=dataset,
)
trainer.train()

2. RewardTrainer - 奖励模型训练器

  • 用途:用于训练奖励模型,该模型可以评估文本生成的质量,通常用于强化学习框架中。
  • 使用方法
    1. 加载预训练的模型和分词器。
    2. 加载适合于奖励模型的数据集。
    3. 配置训练参数。
    4. 初始化 RewardTrainer 并传入必要组件。
    5. 开始训练。
from trl import RewardConfig, RewardTrainer
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from datasets import load_datasettokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
model = AutoModelForSequenceClassification.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct", num_labels=1)
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
training_args = RewardConfig(output_dir="Qwen2.5-0.5B-Reward", per_device_train_batch_size=2)
trainer = RewardTrainer(args=training_args,model=model,processing_class=tokenizer,train_dataset=dataset,
)
trainer.train()

3. GRPOTrainer - 组相对策略优化训练器

  • 用途:基于组相对策略优化算法,适用于更高效的内存利用场景。
  • 使用方法:与上述类似,但需要定义一个奖励函数。
from trl import GRPOConfig, GRPOTrainer
from datasets import load_datasetdef reward_len(completions, **kwargs):return [-abs(20 - len(completion)) for completion in completions]dataset = load_dataset("trl-lib/tldr", split="train")
training_args = GRPOConfig(output_dir="Qwen2-0.5B-GRPO", logging_steps=10)
trainer = GRPOTrainer(model="Qwen/Qwen2-0.5B-Instruct",reward_funcs=reward_len,args=training_args,train_dataset=dataset,
)
trainer.train()

4. DPOTrainer - 直接偏好优化训练器

  • 用途:用于根据人类偏好直接优化语言模型。
  • 使用方法:加载模型和分词器,配置训练参数,并初始化 DPOTrainer
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer
from datasets import load_datasetmodel = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO")
trainer = DPOTrainer(model=model, args=training_args, train_dataset=dataset, processing_class=tokenizer)
trainer.train()

安装

pip install trl

根据自己的需求定制

git clone https://github.com/huggingface/trl.git
cd trl/
pip install -e .[dev]

-e 参数:表示“editable”模式安装。这意味着可以在开发过程中直接从源代码目录运行程序,而不需要每次修改后重新安装。这对于开发阶段非常有用,因为它允许在原地编辑代码,并立即看到更改的效果。

. (点):这表示当前目录是需要被安装的Python包或项目。也就是说,pip会在这个目录下寻找setup.py文件(或pyproject.toml等),并根据里面定义的信息进行安装。

[dev]:这部分指定了一个额外的依赖集,通常在setup.py文件中的extras_require字段中定义。[dev]意味着除了基本依赖外,还将安装标记为开发所需的额外依赖项


http://www.ppmy.cn/ops/157484.html

相关文章

使用deepseek快速创作ppt

目录 1.在DeekSeek生成PPT脚本2.打开Kimi3.最终效果 DeepSeek作为目前最强大模型,其推理能力炸裂,但是DeepSeek官方没有提供生成PPT功能,如果让DeepSeek做PPT呢? 有个途径:在DeepSeek让其深度思考做出PPT脚本&#xf…

【DeepSeek】Deepseek辅组编程-通过卫星轨道计算终端距离、相对速度和多普勒频移

引言 笔者在前面的文章中,介绍了基于卫星轨道参数如何计算终端和卫星的距离,相对速度和多普勒频移。 【一文读懂】卫星轨道的轨道参数(六根数)和位置速度矢量转换及其在终端距离、相对速度和多普勒频移计算中的应用 Matlab程序 …

【R】Dijkstra算法求最短路径

使用R语言实现Dijkstra算法求最短路径 求点2、3、4、5、6、7到点1的最短距离和路径 1.设置data,存放有向图信息 data中每个点所在的行序号为起始点序号,列为终点序号。 比如:值4的坐标为(1,2)即点1到点2距离为4;值8的坐标为(6,7)…

【万字详细教程】Linux to go——装在移动硬盘里的Linux系统(Ubuntu22.04)制作流程;一口气解决系统安装引导文件迁移显卡驱动安装等问题

Linux to go制作流程 0.写在前面 关于教程Why Linux to go?实际效果 1.准备工具2.制作步骤 下载系统镜像硬盘分区准备启动U盘安装系统重启完成驱动安装将系统启动引导程序迁移到移动硬盘上 3.可能出现的问题 3.1.U盘引导系统安装时出现崩溃3.2.不影响硬盘里本身已有…

桥接模式——C++实现

目录 1 桥接模式定义 2 桥接模式小例子 3 桥接模式的总结 1 桥接模式定义 首先来看看桥接模式的官方定义吧: 桥接模式是将抽象部分与它的实现部分分离,使它们都可以独立地变化。它是一种结构型模式。 桥接模式的定义确实比较难理解,比较抽…

网络安全架构分层 网络安全组织架构

1.1.4 网络安全系统的基本组成 上节介绍到了,网络安全系统是一个相对完整的安全保障体系。那么这些安全保障措施具体包括哪些,又如何体现呢?这可以从OSI/RM的7层网络结构来一一分析。因为计算机的网络通信,都离不开OSIR/RM的这7层…

【图片合并转换PDF】如何将每个文件夹下的图片转化成PDF并合并成一个文件?下面基于C++的方式教你实现

医院在为患者进行诊断和治疗过程中,会产生大量的医学影像图片,如 X 光片、CT 扫描图、MRI 图像等。这些图片通常会按照检查时间或者检查项目存放在不同的文件夹中。为了方便医生查阅和患者病历的长期保存,需要将每个患者文件夹下的图片合并成…

【系统架构设计师】嵌入式系统之JTAG接口

目录 1. 说明2. 主要功能2.1 硬件调试2.2 边界扫描测试2.3 系统内编程(ISP)2.4 配置和重新配置2.5 实时监控 3. 核心组件和引脚4. 应用场景5. 使用注意事项6. 例题6.1 例题1 1. 说明 1.嵌入式系统中的JTAG(Joint Test Action Group&#xff…