GLM4 PyTorch模型微调最佳实践

embedded/2024/11/21 23:04:25/

一 引言

2024年6月,智谱AI发布的GLM-4-9B系列开源模型,在语义、数学、推理、代码和知识等多方面的数据集测评中,GLM-4-9B和GLM-4-9B-Chat均表现出超越Llama-3-8B的卓越性能。并且,本代模型新增对26种语言的支持,涵盖日语、韩语、德语等。除此之外,智谱AI还推出了支持1M上下文长度的GLM-4-9B-Chat-1M模型和基于GLM-4-9B的多模态模型。以下为GLM-4-9B系列模型的具体评测结果。

  • 对话模型典型任务

在这里插入图片描述

  • 基座模型典型任务
    在这里插入图片描述

由于GLM-4-9B在预训练过程中加入了部分数学、推理和代码相关的instruction数据,所以将Llama-3-8B-Instruct也列入比较范围。

  • 长文本

在1M的上下文长度下进行大海捞针实验,结果如下:
在这里插入图片描述

在LongBench-Chat上对长文本能力进行了进一步评测,结果如下:
在这里插入图片描述

二 环境准备

2.1 安装Ascend CANN Toolkit和Kernels

安装方法请参考安装教程或使用以下命令:

# 请替换URL为CANN版本和设备型号对应的URL
# 安装CANN Toolkit
wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/Milan-ASL/Milan-ASL%20V100R001C17SPC701/Ascend-cann-toolkit_8.0.RC1.alpha001_linux-"$(uname -i)".run
bash Ascend-cann-toolkit_8.0.RC1.alpha001_linux-"$(uname -i)".run --install# 安装CANN Kernels
wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/Milan-ASL/Milan-ASL%20V100R001C17SPC701/Ascend-cann-kernels-910b_8.0.RC1.alpha001_linux.run
bash Ascend-cann-kernels-910b_8.0.RC1.alpha001_linux.run --install# 设置环境变量
source /usr/local/Ascend/ascend-toolkit/set_env.sh

2.2 安装openMind Hub Client和openMind Library

  • 安装openMind Hub Client
pip install openmind_hub
  • 安装openMind Library,并安装PyTorch框架及其依赖。
pip install openmind[pt]

更详细的安装信息请参考魔乐社区的环境安装章节。

2.3 安装LLaMa Factory

git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git
cd LLaMA-Factory
pip install -e ".[torch-npu,metrics]"

三 模型链接和下载

GLM-4-9B模型系列由社区开发者在魔乐社区贡献,包括:

  • GLM-4-9B:https://modelers.cn/models/AI-Research/glm-4-9b

  • GLM-4-9B-Chat:https://modelers.cn/models/AI-Research/glm-4-9b

  • GLM-4-9B-Chat-1m:https://modelers.cn/models/AI-Research/glm-4-9b-chat-1m

通过Git从魔乐社区下载模型的repo,以GLM-4-9B-Chat为例:


# 首先保证已安装git-lfs(https://git-lfs.com)
git lfs install
git clone https://modelers.cn/AI-Research/glm-4-9b-chat.git

四 模型推理

用户可以使用openMind Library或者LLaMa Factory进行模型推理,以GLM-4-9B-Chat为例,具体如下:

  • 使用openMind Library进行模型推理

新建推理脚本inference_glm4_9b_chat.py,推理脚本内容为:

python">import torch
from openmind import AutoModelForCausalLM, AutoTokenizerdevice = "npu"# 若模型已下载,可替换成模型本地路径
tokenizer = AutoTokenizer.from_pretrained("AI-Research/glm-4-9b-chat", trust_remote_code=True)query = "你好"inputs = tokenizer.apply_chat_template([{"role": "user", "content": query}],add_generation_prompt=True,tokenize=True,return_tensors="pt",return_dict=True)inputs = inputs.to(device)
model = AutoModelForCausalLM.from_pretrained("AI-Research/glm-4-9b-chat",torch_dtype=torch.bfloat16,low_cpu_mem_usage=True,trust_remote_code=True
).to(device).eval()gen_kwargs = {"max_length": 2500, "do_sample": True, "top_k": 1}
with torch.no_grad():outputs = model.generate(**inputs, **gen_kwargs)outputs = outputs[:, inputs['input_ids'].shape[1]:]print(tokenizer.decode(outputs[0], skip_special_tokens=True))

执行推理脚本:

python">python inference_glm4_9b_chat.py

推理结果如下:

在这里插入图片描述

五 模型微调

我们使用单张昇腾NPU,基于LLaMa Factory框架,采用广告文案生成数据集进行Lora微调,让模型能够根据用户输入的商品关键字生成对应的广告文案。

5.1 数据集

广告文案数据集(AdvertiseGen)任务为根据输入(content)生成一段广告词(summary),分为训练集和验证集。其中,训练集大小为114K,验证集大小为1K。每个样本有content和summary两个键,分别保存商品关键字和商品文案。
以下是部分示例:

python">{"content": "类型#上衣*版型#宽松*版型#显瘦*图案#线条*衣样式#衬衫*衣袖型#泡泡袖*衣款式#抽绳","summary": "这件衬衫的款式非常的宽松,利落的线条可以很好的隐藏身材上的小缺点,穿在身上有着很好的显瘦效果。领口装饰了一个可爱的抽绳,漂亮的绳结展现出了十足的个性,配合时尚的泡泡袖型,尽显女性甜美可爱的气息。"
}
  • 下载AdvertiseGen数据集
    感谢社区开发者在魔乐社区贡献的AdvertiseGen数据集,使用Git将数据集下载至本地。
python">
git lfs install
git clone https://modelers.cn/AI-Research/AdvertiseGen.git
  • 数据预处理
    下载完成后,需要将train.json和dev.json两个文件的数据处理成alpaca数据格式。因此,创建preprocess_adv_gen.py脚本,脚本内容具体如下:
python">import json
import argparse
import os
import statDEFAULT_FLAGS = os.O_WRONLY | os.O_CREAT
DEFAULT_MODES = stat.S_IWUSR | stat.S_IRUSRdef parse_args():parse = argparse.ArgumentParser()parse.add_argument("--data_path", type=str)parse.add_argument("--save_path", type=str)args = parse.parse_args()return argsdef read_data(data_path):data = []with open(data_path, "r", encoding="utf-8") as f:lines = f.readlines()for line in lines:data.append(json.loads(line))return datadef convert_to_alpaca_format(data):results = []for sample in data:example = {}example["instruction"] = sample["content"]example["output"] = sample["summary"]results.append(example)return resultsdef save_data(data, save_path):with os.fdopen(os.open(save_path, DEFAULT_FLAGS, DEFAULT_MODES), "w", encoding="utf-8") as f:json.dump(data, f, indent=4, ensure_ascii=False)if __name__ == "__main__":args = parse_args()data = read_data(args.data_path)data = convert_to_alpaca_format(data)save_data(data, args.save_path)

通过以下命令执行脚本,将数据预处理的结果分别存为adv_gen_train.jsonadv_gen_dev.json

python"># xxx为train,json和dev.json文件路径
python preprocess_adv_gen.py --data_path xxx --save_path ./adv_gen_train.json
python preprocess_adv_gen.py --data_path xxx --save_path ./adv_gen_dev.json

修改LLaMa Factory下的data/dataset_info.json文件,添加数据集描述:


"adv_gen_train": {"file_name": "xxx", // 填写预处理完成的adv_gen_train.json文件路径"columns": {"prompt": "instruction","response": "output"}
},
"adv_gen_dev": {"file_name": "xxx", // 填写预处理完成的adv_gen_dev.json文件路径"columns": {"prompt": "instruction","response": "output"}
},

以上为整个数据预处理流程,在配置文件中使用dataset: adv_gen_train, adv_gen_dev配置即可在微调中使用广告文案生成数据集。

5.2 微调

在LLaMa Factory路径下新建examples/train_lora/glm4_9b_chat_lora_sft.yaml微调配置文件,微调配置文件如下:

### model
model_name_or_path: xxx # 当前仅支持本地加载,填写GLM-4-9B-Chat本地权重路径### method
stage: sft
do_train: true
finetuning_type: lora
lora_target: all
lora_rank: 8
lora_alpha: 32
lora_dropout: 0.1### dataset
dataset: adv_gen_train
template: glm4
cutoff_len: 256
preprocessing_num_workers: 16### output
output_dir: saves/glm4_9b_chat/lora/sft
logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true### train
per_device_train_batch_size: 16
gradient_accumulation_steps: 1
learning_rate: 5.0e-4
max_steps: 1000
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000

通过下面的命令启动微调:

export ASCEND_RT_VISIBLE_DEVICES=0
llamafactory-cli train examples/train_lora/glm4_9b_chat_lora_sft.yaml

5.3 微调可视化

在这里插入图片描述

5.4 微调结果

  • 评估

训练结束后,通过LLaMa Factory使用微调完成的权重在·adv_gen_dev.json·数据集上预测BLEU和ROUGE分数。在LLaMa Factory路径下新建·examples/train_lora/glm4_9b_chat_lora_predict.yaml·推理配置文件,配置文件内容如下:


### model
model_name_or_path: xxx # 当前仅支持本地加载,填写GLM-4-9B-Chat本地权重路径
adapter_name_or_path: saves/glm4_9b_chat/lora/sft/checkpoint-1000/### method
stage: sft
do_predict: true
finetuning_type: lora### dataset
eval_dataset: adv_gen_dev
template: glm4
cutoff_len: 256
preprocessing_num_workers: 16### output
output_dir: saves/glm4_9b_chat/lora/predict
overwrite_output_dir: true### eval
per_device_eval_batch_size: 128
predict_with_generate: true

通过下面的命令启动评估:

export ASCEND_RT_VISIBLE_DEVICES=0
llamafactory-cli train examples/train_lora/glm4_9b_chat_lora_predict.yaml

评估的结果为:
在这里插入图片描述

  • 推理

微调结束后,在LLaMa Factory路径下新建examples/inference/glm4_9b_chat_lora_sft.yaml推理配置文件,配置文件内容为:

model_name_or_path: xxx # 当前仅支持本地加载,填写GLM-4-9B-Chat本地权重路径
adapter_name_or_path: saves/glm4_9b_chat/lora/sft/checkpoint-1000/
template: glm4
finetuning_type: lora

通过下面的命令启动推理:

llamafactory-cli chat examples/inference/glm4_9b_chat_lora_sft.yaml
  • 训练前推理结果为:

    问题:类型#上衣材质#牛仔布颜色#白色风格#简约图案#刺绣衣样式#外套衣款式#破洞

    在这里插入图片描述

  • 训练后推理结果为:

    • 问题1:类型#上衣材质#牛仔布颜色#白色风格#简约图案#刺绣衣样式#外套衣款式#破洞

      在这里插入图片描述

    • 问题2:类型#裤风格#英伦风格#简约

      在这里插入图片描述

    • 问题3:类型#裙裙下摆#弧形裙腰型#高腰裙长#半身裙裙款式#不规则*裙款式#收腰

      在这里插入图片描述

六 总结

本次实践是在魔乐社区进行。朋友们可以试试,也欢迎分享你们的经验,一起交流:https://modelers.cn

如您在体验过程中遇到任何问题,欢迎访问魔乐社区的帮助中心(https://gitee.com/modelers/feedback),与其他用户交流和寻求支持。


http://www.ppmy.cn/embedded/139455.html

相关文章

Vue跨域资源共享

在Vue前端开发中,跨域问题是一个常见的挑战,特别是当你需要从前端应用向不同域名或端口的后端API发送请求时。跨域请求通常会被浏览器的同源策略(Same-Origin Policy)阻止,以确保安全性。 以下是一些解决Vue前端跨域问…

5. langgraph中的react agent使用 (从零构建一个react agent)

1. 定义 Agent 状态 首先,我们需要定义 Agent 的状态,这包括 Agent 所持有的消息。 from typing import (Annotated,Sequence,TypedDict, ) from langchain_core.messages import BaseMessage from langgraph.graph.message import add_messagesclass …

【人工智能】用Python构建词向量模型:从零实现Word2Vec并探索FastText在低频词上的优势

《Python OpenCV从菜鸟到高手》带你进入图像处理与计算机视觉的大门! 词向量是自然语言处理中的关键技术之一,将词语转换为向量表示能够捕捉语义信息并应用于机器学习模型中。本文将介绍词向量的基本概念,通过从零实现Word2Vec模型帮助读者掌握词向量的生成过程。同时,本文…

【软件测试】设计测试用例的万能公式

文章目录 概念设计测试用例的万能公式常规思考逆向思维发散性思维万能公式水杯测试弱网测试如何进行弱网测试 安装卸载测试 概念 什么是测试用例? 测试⽤例(Test Case)是为了实施测试⽽向被测试的系统提供的⼀组集合,这组集合包…

Redis性能优化——针对实习面试

目录 Redis性能优化什么是bigkey?bigkey的危害?如何处理bigkey?什么是hotkey?hotkey的危害?如何处理hotkey?如何处理大量key集中过期问题?什么是内存碎片?为什么会有Redis内存碎片?…

CTF-Hub SQL 字符型注入(纯手动注入)

题目很明确是字符型注入,所有先尝试单引号 由于输入1 出现页面错误,且1不会出现页面错误,推断出该 sql 语句是使用单引号进行闭合的。(因为题目比较简单,已经把执行的 sql 语句一同打印在了底下) 开始注入(…

鸿蒙中服务卡片数据的获取和渲染

1. 2.在卡片中使用LocalStorageProp接受传递的数据 LocalStorageProp("configNewsHead") configNewsHeadLocal: ConfigNewsHeadInfoItem[] [] 注意:LocalStorageProp括号中的为第一步图片2中的键 3.第一次在服务卡片的第一个卡片中可能会获取不到数据…

Java Servlet 详解

一、Servlet的基本概念 Servlet(Server Applet)是Java Servlet的简称,是Java语言编写的服务器端程序。Servlet主要用于处理HTTP请求和生成HTTP响应,可以完成B/S架构下客户端请求的响应处理,交互式地浏览和生成数据&am…