TRL - Transformer Reinforcement Learning(基于Transformer的强化学习)
flyfish
地址:https://github.com/huggingface/trl
TRL是一个用于对基础模型进行后训练的全面库,是一个前沿的库,专门用于使用先进的技术(如监督微调(SFT)、近端策略优化(PPO)和直接偏好优化(DPO))对基础模型进行后训练。
利用Accelerate,通过DDP和DeepSpeed等方法,从单个GPU扩展到多节点集群。与PEFT完全集成,通过量化和LoRA/QLoRA,即使在硬件资源有限的情况下也能训练大型模型。集成Unsloth,通过优化的内核加速训练。
通过SFTTrainer、DPOTrainer、RewardTrainer、ORPOTrainer等训练器,轻松访问各种微调方法。使用预定义的模型类(如AutoModelForCausalLMWithValueHead)简化与LLMs的强化学习(RL)
如何使用 TRL(Transformers Reinforcement Learning)库
1. SFTTrainer
- 监督微调训练器
- 用途:用于在自定义数据集上进行监督微调(Supervised Fine-Tuning, SFT)。
- 使用方法:
- 加载你的训练数据集,例如
"trl-lib/Capybara"
。 - 配置训练参数,包括输出目录等。
- 初始化
SFTTrainer
并传入模型、训练数据集等参数。 - 调用
.train()
方法开始训练。
- 加载你的训练数据集,例如
from trl import SFTConfig, SFTTrainer
from datasets import load_datasetdataset = load_dataset("trl-lib/Capybara", split="train")
training_args = SFTConfig(output_dir="Qwen/Qwen2.5-0.5B-SFT")
trainer = SFTTrainer(args=training_args,model="Qwen/Qwen2.5-0.5B",train_dataset=dataset,
)
trainer.train()
2. RewardTrainer
- 奖励模型训练器
- 用途:用于训练奖励模型,该模型可以评估文本生成的质量,通常用于强化学习框架中。
- 使用方法:
- 加载预训练的模型和分词器。
- 加载适合于奖励模型的数据集。
- 配置训练参数。
- 初始化
RewardTrainer
并传入必要组件。 - 开始训练。
from trl import RewardConfig, RewardTrainer
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from datasets import load_datasettokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
model = AutoModelForSequenceClassification.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct", num_labels=1)
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
training_args = RewardConfig(output_dir="Qwen2.5-0.5B-Reward", per_device_train_batch_size=2)
trainer = RewardTrainer(args=training_args,model=model,processing_class=tokenizer,train_dataset=dataset,
)
trainer.train()
3. GRPOTrainer
- 组相对策略优化训练器
- 用途:基于组相对策略优化算法,适用于更高效的内存利用场景。
- 使用方法:与上述类似,但需要定义一个奖励函数。
from trl import GRPOConfig, GRPOTrainer
from datasets import load_datasetdef reward_len(completions, **kwargs):return [-abs(20 - len(completion)) for completion in completions]dataset = load_dataset("trl-lib/tldr", split="train")
training_args = GRPOConfig(output_dir="Qwen2-0.5B-GRPO", logging_steps=10)
trainer = GRPOTrainer(model="Qwen/Qwen2-0.5B-Instruct",reward_funcs=reward_len,args=training_args,train_dataset=dataset,
)
trainer.train()
4. DPOTrainer
- 直接偏好优化训练器
- 用途:用于根据人类偏好直接优化语言模型。
- 使用方法:加载模型和分词器,配置训练参数,并初始化
DPOTrainer
。
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer
from datasets import load_datasetmodel = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO")
trainer = DPOTrainer(model=model, args=training_args, train_dataset=dataset, processing_class=tokenizer)
trainer.train()
安装
pip install trl
根据自己的需求定制
git clone https://github.com/huggingface/trl.git
cd trl/
pip install -e .[dev]
-e 参数:表示“editable”模式安装。这意味着可以在开发过程中直接从源代码目录运行程序,而不需要每次修改后重新安装。这对于开发阶段非常有用,因为它允许在原地编辑代码,并立即看到更改的效果。
. (点):这表示当前目录是需要被安装的Python包或项目。也就是说,pip会在这个目录下寻找setup.py文件(或pyproject.toml等),并根据里面定义的信息进行安装。
[dev]:这部分指定了一个额外的依赖集,通常在setup.py文件中的extras_require字段中定义。[dev]意味着除了基本依赖外,还将安装标记为开发所需的额外依赖项