LLaMA-Factory GLM4-9B-CHAT LoRA 指令微调实战

news/2024/12/29 4:39:06/

LoRA_font_0">🤩LLaMA-Factory GLM LoRA 微调

llamafactory_1">安装llama-factory包

git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git

进入下载好的llama-factory,安装依赖包

cd LLaMA-Factory
pip install -e ".[torch,metrics]"
#上面这步操作会完成torch、transformers、datasets等相关依赖包的安装

LLaMAFactory目录结构

大模型下载

modelscope代码方式下载

import torch
from modelscope import snapshot_download, AutoModel, AutoTokenizer
import os
model_dir = snapshot_download('ZhipuAI/glm-4-9b-chat', cache_dir='/root/autodl-tmp', revision='master')

git clone方式下载

git lfs install # 大文件传输
sudo apt-get install git-lfsgit clone https://www.modelscope.cn/ZhipuAI/glm-4-9b-chat.git

指令数据集构建-Alpaca 格式

Alpaca 格式是一种用于训练自然语言处理(NLP)模型的数据集格式,特别是在对话系统和问答系统中。这种格式通常包含指令(instruction)、输入(input)和输出(output)三个部分,它们分别对应模型的提示、模型的输入和模型的预期输出。三者的数据都是字符串形式

Alpaca 格式训练数据集

[{"instruction": "Answer the following question about the movie 'Inception'.","input": "What is the main theme of the movie?","output": "The main theme of the movie 'Inception' is the exploration of dreams and reality."},{"instruction": "Provide a summary of the book '1984' by George Orwell.","input": "What is the book '1984' about?","output": "The book '1984' is a dystopian novel that tells the story of a totalitarian regime and the protagonist's struggle against it."}
]

我的数据集构建如下

crop_train.json

[{"instruction": "你是农作物领域专门进行关系抽取的专家。请从给定的文本中抽取出关系三元组,不存在的关系返回空列表。请按照JSON字符串的格式回答。","input": "煤是一种常见的化石燃料,家庭用煤经过了从\"煤球\"到\"蜂窝煤\"的演变。","output": "[{\"head\": \"煤\", \"relation\": \"use\", \"tail\": \"燃料\"}]"},{"instruction": "你是农作物领域专门进行关系抽取的专家。请从给定的文本中抽取出关系三元组,不存在的关系返回空列表。请按照JSON字符串的格式回答。","input": "内分泌疾病是指内分泌腺或内分泌组织本身的分泌功能和(或)结构异常时发生的症候群。","output": "[{\"head\": \"腺\", \"relation\": \"use\", \"tail\": \"分泌\"}]"},
]

我的数据格式转换代码:

import json
import re# 选择要格式转换的数据集
file_name = "merged_trainProcess.json"# 读取原始数据
with open(f'./{file_name}', 'r', encoding='utf-8') as file:data = json.load(file)# 转换数据格式
converted_data = [{"instruction": item["instruction"],"input": item["text"],"output": json.dumps(item["triplets"], ensure_ascii=False),} for item in data]# 将转换后的数据写入新文件
output_file_name = f'processed_{file_name}'
with open(output_file_name, 'w', encoding='utf-8') as file:json.dump(converted_data, file, ensure_ascii=False, indent=4)print(f'{output_file_name} Done')

llamafactory_112">构建好后保存到llama-factory目录中某文件下

修改 LLaMa-Factory 目录中的 data/dataset_info.json 文件,在其中添加:

"crop_merged": {"file_name": "/home/featurize/data/crop_train.json"		#自己的训练数据集.json文件的绝对路径}

微调模型代码

在 LLaMA-Factory 目录中**新建配置文件 crop_glm4_lora_sft.yaml :**

### model:glm-4-9b-chat模型地址的绝对路径
model_name_or_path: /home/featurize/glm-4-9b-chat### method
stage: sft			# supervised fine-tuning(监督式微调
do_train: true		# 是否执行训练过程
finetuning_type: lora		# 微调技术的类型
lora_target: all			# all 表示对模型的所有参数进行 LoRA 微调### dataset
# dataset 要和 data/dataset_info.json 中添加的信息保持一致
dataset: crop_merged			# 数据集的名称
template: glm4					# 数据集的模板类型
cutoff_len: 2048				# 输入序列的最大长度
max_samples: 100000				# 最大样本数量, 代表在训练过程中最多只会用到训练集的1000条数据;-1代表训练所有训练数据集
overwrite_cache: true	
preprocessing_num_workers: 16		# 数据预处理时使用的进程数			### output
# output_dir是模型训练过程中的checkpoint,训练日志等的保存目录
output_dir: saves/crop-glm4-epoch10/lora/sft
logging_steps: 10		# 日志记录的频率
#save_steps: 500
plot_loss: true			# 绘制损失曲线
overwrite_output_dir: true		# 是否覆盖输出目录中的现有内容
save_strategy: epoch			# 保存策略,epoch 表示每个 epoch 结束时保存。### train
per_device_train_batch_size: 1		# 每个设备上的批次大小,每增加几乎显存翻倍
gradient_accumulation_steps: 8		# 梯度累积的步数,这里设置为 8,意味着每 8 步执行一次梯度更新。
learning_rate: 1.0e-4				# 学习率
num_train_epochs: 3				# 训练的 epoch 数
lr_scheduler_type: cosine			#学习率调度器的类型
warmup_ratio: 0.1
fp16: true							# 更适合模型推理,混合精度训练,在训练中同时使用FP16和FP32
#bf16: true							# 更适合训练阶段
gradient_checkpointing: true		# 启用梯度检查点(gradient checkpointing),减少中间激活值的显存占用### eval
do_eval: false				# 是否进行评估
val_size: 0.1				# 验证集的大小比例,这里设置为 0.1,即 10%
per_device_eval_batch_size: 1		# 每个设备上的评估批次大小
eval_strategy: steps		# 评估策略,steps 表示根据步数进行评估
eval_steps: 1000				# 评估的步数间隔
### model 
model_name_or_path: /home/featurize/glm-4-9b-chat### method
stage: sft
do_train: true
finetuning_type: lora
lora_target: all### dataset
dataset: crop_merged
template: glm4
cutoff_len: 512
overwrite_cache: true
preprocessing_num_workers: 16### output
output_dir: saves/crop-glm4-epoch10/lora/sft
logging_steps: 100
save_steps: 1000
plot_loss: true
overwrite_output_dir: true### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
learning_rate: 1.0e-4
num_train_epochs: 1
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
gradient_checkpointing: true### eval
val_size: 0.1
per_device_eval_batch_size: 1
eval_strategy: epoch

执行以下命令开始微调

cd LLaMA-Factory
llamafactory-cli train crop_glm4_lora_sft.yaml

合并模型代码

训练完成后,在 LLaMA-Factory 目录中**新建配置文件 crop_glm4_lora_sft_export.yaml:**

导出的模型将被保存在models/CropLLM-glm-4-9b-chat目录下。

### model
model_name_or_path: /home/featurize/glm-4-9b-chat
# 刚才crop_glm4_lora_sft.yaml文件中的 output_dir
adapter_name_or_path: saves/crop-glm4-epoch10/lora/sft
template: glm4
finetuning_type: lora### export
export_dir: models/CropLLM-glm-4-9b-chat
export_size: 2
export_device: cpu
export_legacy_format: false

半精度(FP16、BF16)

执行以下命令开始合并模型:

cd LLaMA-Factory
llamafactory-cli export crop_glm4_lora_sft_export.yaml

models/CropLLM-glm-4-9b-chat 目录中就可以获得经过Lora微调后的完整模型

在这里插入图片描述

推理预测

import torch
from transformers import AutoModelForCausalLM, AutoTokenizermodel_path = "/home/featurize/LLaMA-Factory/models/CropLLM-glm-4-9b-chat"device = "cuda"
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)system_prompt = "你是农作物领域专门进行关系抽取的专家。请从给定的文本中抽取出关系三元组,不存在的关系返回空列表。请按照JSON字符串的格式回答。"
input = "玉米黑粉病又名“乌霉”和“瘤黑粉病”,病原菌是真菌,:担孢子菌,是由于玉米黑粉菌所引起的一种局部浸染性病害。病瘤内的黑粉是病菌的"inputs = tokenizer.apply_chat_template([{"role": "system", "content": system_prompt},{"role": "user", "content": input}],add_generation_prompt=True,tokenize=True,return_tensors="pt",return_dict=True)inputs = inputs.to(device)
model = AutoModelForCausalLM.from_pretrained(model_path,torch_dtype=torch.bfloat16,low_cpu_mem_usage=True,trust_remote_code=True
).to(device).eval()gen_kwargs = {"max_length": 2500, "do_sample": True, "top_k": 1}
with torch.no_grad():outputs = model.generate(**inputs, **gen_kwargs)outputs = outputs[:, inputs['input_ids'].shape[1]:]print(f"model: {model_path}")print(f"task: {system_prompt}")print(f"input: {input}")output = tokenizer.decode(outputs[0], skip_special_tokens=True)print(f"output: {output}")

http://www.ppmy.cn/news/1558974.html

相关文章

【YOLOv3】源码(train.py)

概述 主要模块分析 参数解析与初始化 功能:解析命令行参数,设置训练配置项目经理制定详细的施工计划和资源分配日志记录与监控 功能:初始化日志记录器,配置监控系统项目经理使用监控和记录工具,实时跟踪施工进度和质量…

HarmonyOS NEXT 实战之元服务:静态案例效果--- 日出日落

背景: 前几篇学习了元服务,后面几期就让我们开发简单的元服务吧,里面丰富的内容大家自己加,本期案例 仅供参考 先上本期效果图 ,里面图片自行替换 效果图1完整代码案例如下: import { authentication } …

swagger,showdoc,apifox,Mock 服务,dubbo,ZooKeeper和dubbo的关系

Swagger、ShowDoc 和 Apifox 之间的区别与优势 Swagger、ShowDoc 和 Apifox 都是用于 API 文档管理和测试的工具,但它们各有特色和适用场景。以下是详细的比较,并附上每个工具的具体用法示例。 1. Swagger 特点与优势: 广泛采用: Swagger…

【玩转OCR】 | 腾讯云智能结构化OCR在多场景的实际应用与体验

文章目录 引言产品简介产品功能产品优势 API调用与场景实践图像增强API调用实例发票API调用实例其他场景 结语相关链接 引言 在数字化信息处理的时代,如何高效、精准地提取和结构化各类文档数据成为了企业和政府部门的重要需求。尤其是在面对海量票据、证件、表单和…

基于Linux内核防火墙-数据包转发和转换详解

大家觉得有意义和参考价值记得关注和点赞!!! 本文技术含量稍微有点偏高,读起来理解来有点难,在具体读懂之前将前面的一些概念了解清楚,不然后续很变扭。 一、概念介绍 Linux系统的防火墙:IP信息…

Java圣诞树

目录 写在前面 技术需求 程序设计 代码分析 一、代码结构与主要功能概述 二、代码功能分解与分析 1. 类与常量定义 2. 绘制树的主逻辑 3. 彩色球的绘制 4. 动态效果的实现 5. 窗口初始化 三、关键特性与优点 四、总结 写在后面 写在前面 Java语言绘制精美圣诞树…

STM32高级物联网通信之以太网通讯

目录 以太网通讯基础知识 什么是以太网 互联网和以太网的区别 1)概念与范围 (1)互联网 (2)以太网 2)技术特点 (1)互联网 (2)以太网 3)应用场景 (1)互联网 (2)以太网 以太网的层次 1)物理层 2)数据链路层 OSI 7层模型 TCPIP 4层模型 一些常见…

机器学习2-NumPy

ndarray自动广播扩展维度,便于进行行列式,数组计算 # 自动广播机制,1维数组和2维数组相加# 二维数组维度 2x5 # array([[ 1, 2, 3, 4, 5], # [ 6, 7, 8, 9, 10]]) d np.array([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]) # c是一…