基于 LlamaFactory 的 LoRA 微调模型支持 vllm 批量推理的实现

news/2024/12/4 18:26:44/

背景

LlamaFactory 的 LoRA 微调功能非常便捷,微调后的模型,没有直接支持 vllm 推理,故导致推理速度不够快。

LlamaFactory 目前支持通过 VLLM API 进行部署,调用 API 时的响应速度,仍然没有vllm批量推理的速度快。

如果模型是通过 LlamaFactory 微调的,为了确保数据集的一致性,建议在推理时也使用 LlamaFactory 提供的封装数据集。

简介

在上述的背景下,我们使用 LlamaFactory 原生数据集,支持 lora的 vllm 批量推理。
完整代码如下:

import json
import os
from typing import Listfrom vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequestfrom llamafactory.data import get_dataset, get_template_and_fix_tokenizer
from llamafactory.extras.constants import IGNORE_INDEX
from llamafactory.hparams import get_train_args
from llamafactory.model import load_tokenizerdef vllm_infer():model_args, data_args, training_args, finetuning_args, generating_args = (get_train_args())tokenizer = load_tokenizer(model_args)["tokenizer"]template = get_template_and_fix_tokenizer(tokenizer, data_args)eval_dataset = get_dataset(template, model_args, data_args, training_args, finetuning_args.stage, tokenizer)["eval_dataset"]prompts = [item["input_ids"] for item in eval_dataset]prompts = tokenizer.batch_decode(prompts, skip_special_tokens=False)labels = [list(filter(lambda x: x != IGNORE_INDEX, item["labels"]))for item in eval_dataset]labels = tokenizer.batch_decode(labels, skip_special_tokens=True)sampling_params = SamplingParams(temperature=generating_args.temperature,top_k=generating_args.top_k,top_p=generating_args.top_p,max_tokens=2048,)if model_args.adapter_name_or_path:if isinstance(model_args.adapter_name_or_path, list):lora_requests = []for i, _lora_path in enumerate(model_args.adapter_name_or_path):lora_requests.append(LoRARequest(f"lora_adapter_{i}", i, lora_path=_lora_path))else:lora_requests = LoRARequest("lora_adapter_0", 0, lora_path=model_args.adapter_name_or_path)enable_lora = Trueelse:lora_requests = Noneenable_lora = Falsellm = LLM(model=model_args.model_name_or_path,trust_remote_code=True,tokenizer=model_args.model_name_or_path,enable_lora=enable_lora,)outputs = llm.generate(prompts, sampling_params, lora_request=lora_requests)if not os.path.exists(training_args.output_dir):os.makedirs(training_args.output_dir, exist_ok=True)output_prediction_file = os.path.join(training_args.output_dir, "generated_predictions.jsonl")with open(output_prediction_file, "w", encoding="utf-8") as writer:res: List[str] = []for text, pred, label in zip(prompts, outputs, labels):res.append(json.dumps({"prompt": text, "predict": pred.outputs[0].text, "label": label},ensure_ascii=False,))writer.write("\n".join(res))

vllm.yaml 示例:

## model
model_name_or_path: qwen/Qwen2.5-7B-Instruct
# adapter_name_or_path: lora模型### method
stage: sft
do_predict: true
finetuning_type: lora### dataset
dataset_dir: 数据集路径
eval_dataset: 数据集
template: qwen
cutoff_len: 1024
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16### output
output_dir: output/
overwrite_output_dir: true### eval
predict_with_generate: true

程序调用:

python vllm_infer.py vllm.yaml

程序运行速度:

Processed prompts: 100%|| 1000/1000 [01:56<00:00,  8.60it/s, est. speed input: 5169.35 toks/s, output: 811.57

总结

本方案在原生 LlamaFactory 数据集的基础上,支持 LoRA 的 vllm 批量推理,能提升了推理效率。

进一步阅读

如果微调模型后,发现使用vllm模型批量效果不太好,可以参考下述文章:

  • 基于 LLamafactory 的异步API高效调用实现与速度对比.https://blog.csdn.net/sjxgghg/article/details/144176645

亲测,LLamafactory 部署 模型,然后使用 Async API 调用后评估效果会好一些。


http://www.ppmy.cn/news/1552349.html

相关文章

FreeRTOS之ARM CR5栈结构操作示意图

FreeRTOS之ARM CR5栈结构操作示意图 1 FreeRTOS源码下载地址2 ARM CR5栈结构操作宏和接口2.1 portSAVE_CONTEXT宏2.1.1 portSAVE_CONTEXT源码2.1.2 portSAVE_CONTEXT宏操作栈结构变化示意图 2.2 portRESTORE_CONTEXT宏2.2.1 portRESTORE_CONTEXT源码2.2.2 portRESTORE_CONTEXT宏…

Nginx管理维护运维规范

Nginx管理维护运维规范 一、版本约束 软件的不同版本&#xff0c;在使用起来都有可能带来不可预知的影响&#xff0c;因此需要统一整理&#xff0c;固定下来&#xff0c;不允许轻易变更。 在Nginx开源版官网&#xff0c;点击右侧download可以看到各个版本的Nginx&#xff0c;…

第六届金盾信安杯Web题解

比赛一共4道Web题,比赛时只做出三道,那道文件上传没有做出来,所以这里是另外三道题的WP 分别是 fillllll_put hoverfly ssrf fillllll_put 涉及: 绕过exit() 死亡函数 php://filter 伪协议配合base64加解密 一句话木马 题目源码&#xff1a; $content参数在开头被…

Linux笔试题(自己整理,已做完,选择题)

详细Linux内容查看&#xff1a;day04【入门】Linux系统操作-CSDN博客 1、部分笔试题 本文的笔试题&#xff0c;主要是为了复习学习的day04【入门】Linux系统操作-CSDN博客的相关知识点。后续还会更新一些面试相关的题目。 欢迎一起学习

MongoDB集群分片安装部署手册

文章目录 一、集群规划1.1 集群安装规划1.2 端口规划1.3 目录创建 二、mongodb安装&#xff08;三台均需要操作&#xff09;2.1 下载、解压2.2 配置环境变量 三、mongodb组件配置3.1 配置config server的副本集3.1.1 config配置文件3.1.2 config server启动3.1.3 初始化config …

链表的分类以及双向链表的实现

1.链表的分类 1.1单向链表与双向链表的区别 单向链表只能从头结点遍历到尾结点。 双向链表不仅能从头结点遍历到尾结点&#xff0c;还能从尾结点遍历到头结点。 1.2带头链表与不带头链表的区别 1.3循环链表与不循环链表的区别&#xff08;也称为带环链表与不带环链表&#x…

【嵌入式系统设计】LES3~5:Cortex-M4系统架构(上)第1节 ARM处理器,M4内核处理器,M4调试跟踪接口

关注作者了解更多 我的其他CSDN专栏 过程控制系统 工程测试技术 虚拟仪器技术 可编程控制器 工业现场总线 数字图像处理 智能控制 传感器技术 嵌入式系统 复变函数与积分变换 单片机原理 线性代数 大学物理 热工与工程流体力学 数字信号处理 光电融合集成电路…

LeetCode - #150 逆波兰表达式求值

文章目录 前言1. 描述2. 示例3. 答案关于我们 前言 我们社区陆续会将顾毅&#xff08;Netflix 增长黑客&#xff0c;《iOS 面试之道》作者&#xff0c;ACE 职业健身教练。&#xff09;的 Swift 算法题题解整理为文字版以方便大家学习与阅读。 LeetCode 算法到目前我们已经更新…