使用Chatglm-6b微调催收问答对的尝试

news/2025/1/25 8:47:36/

1.工作目录,如:mnt/d/work,下载源代码,并安装依赖

git clone https://github.com/THUDM/ChatGLM-6B
cd ChatGLM-6B
pip install -r requirement.txt

2. 从拥抱脸下载chatglm-6b-int4-qe到本地(GPU环境搭建参考浪潮服务器 Tesla T4 16G GPU 环境配置_wxl781227的博客-CSDN博客)

3. 安装docker及nvidia-docker2环境(自行百度知乎)

4. docker下载nvidia容器(强烈推荐使用nvidia容器,省去很多麻烦)

sudo docker pull nvcr.io/nvidia/pytorch:23.01-py3

5. 在工作目录下,启动nvidia容器,多开放几个端口,方便映射。

sudo docker run --name gpt --gpus=all --ipc=host --ulimit memlock=-1 --rm -it -p 6006:6006 -p 8500:8500 -p 8501:8501 -p 8502:8502 -v $PWD:/gpt -v /etc/localtime:/etc/localtime:ro  -d nvcr.io/nvidia/pytorch:23.01-py3

6. 整理催收问答对训练train.json及验证文件val.json

7. 在gpt容器中启动jupyter lab,方便修改代码

jupyter lab --port=6006 --allow-root

8. 打开localhost:6006/lab,输入token,这样就可以方便的上传文件,修改代码了。

 9. 进入ptuning修改训练脚本train.sh

PRE_SEQ_LEN=128
LR=2e-2CUDA_VISIBLE_DEVICES=0 python3 main.py \--do_train \--train_file /gpt/ChatGLM-6B/train.json \--validation_file /gpt/ChatGLM-6B/val.json \--prompt_column content \--response_column summary \--overwrite_cache \--model_name_or_path /gpt/chatglm-6b-int4-qe \--output_dir check_point/chatglm-6b-int4-qe-pt-$PRE_SEQ_LEN-$LR \--overwrite_output_dir \--max_source_length 64 \--max_target_length 64 \--per_device_train_batch_size 1 \--per_device_eval_batch_size 1 \--gradient_accumulation_steps 16 \--predict_with_generate \--max_steps 3000 \--logging_steps 10 \--save_steps 1000 \--learning_rate $LR \--pre_seq_len $PRE_SEQ_LEN \--quantization_bit 4

per_device_train_batch_size根据GPU大小进行调整,如:1,2,4,8,16等。

max_source_length及max_target_length 可以根据实际情况调整,对应的是输出和输出的长度。

10.在gpt容器中启动微调训练,根据数据量不同及GPU大小,有所不同,tesla T4 16 G 30个问答对大概要24小时。

sh train.sh

11. 评估训练的效果(修改evaluate.sh,并执行)

PRE_SEQ_LEN=128
CHECKPOINT=adgen-chatglm-6b-pt-128-2e-2
STEP=3000CUDA_VISIBLE_DEVICES=0 python3 main.py \--do_predict \--validation_file /gpt/val.json \--test_file /gpt/val.json \--overwrite_cache \--prompt_column content \--response_column summary \--model_name_or_path /gpt/chatglm-6b-int4-qe \--ptuning_checkpoint check_point/$CHECKPOINT/checkpoint-$STEP \--output_dir output/$CHECKPOINT \--overwrite_output_dir \--max_source_length 64 \--max_target_length 64 \--per_device_eval_batch_size 1 \--predict_with_generate \--pre_seq_len $PRE_SEQ_LEN \--quantization_bit 4
sh evaluate.sh

12. 启动微调后的模型(修改web_demo.sh,并运行,默认端口7860,web_demo.py中修改为server_port=8500)

# web_demo.py
# demo.queue().launch(share=False, inbrowser=True)
demo.queue().launch(share=False, inbrowser=True, server_port=8500)
PRE_SEQ_LEN=128CUDA_VISIBLE_DEVICES=0 python3 web_demo.py \--model_name_or_path /gpt/chatglm-6b-int4-qe \--ptuning_checkpoint check_point/chatglm-6b-int4-qe-pt-128-2e-2/checkpoint-3000 \--pre_seq_len $PRE_SEQ_LEN
sh web_demo.sh

13. 打开界面实际测试

 

14.结论

微调后的效果与预期一致,能够有效控制输出结果,即输出完全按照问答对训练预料来输出的。


http://www.ppmy.cn/news/67132.html

相关文章

9.设计模式之适配器模式

前言 适配器模式是我们在很多框架中会见到的设计模式。它主要解决的是两种接口不兼容而不能一起使用的问题。通过适配器,将原来的接口适配成目标接口,在不改动老代码的同时,实现两者同时使用。生活中常见的例子比如手机电源充电适配器。 本…

单片机GD32F303RCT6 (Macos环境)开发 (十七)—— i2c1从机中断接收发送数据

i2c1从机中断接收发送数据 1、将i2c1设置为从机模式,与树莓派连接。树莓派发送或者读取数据,gd32中断触发,从而接收数据或者向主机发送数据。 2、关于代码的宏定义配置 Application目录的Makefile中 ENABLE_I2C_TEST yes才会编译I2C1的相关…

Springboot +Flowable,设置流程变量的方式(二)

一.简介 为什么需要流程变量。 举个例子,假设有如下一个流程,截图如下: 这是一个请假流程,那么谁请假、请几天、起始时间、请假理由等等,这些都需要说明,不然领导审批的依据是啥?那么如何传递…

开源项目九死一生,但很多程序员坚持开源??

大家好,欢迎来到停止重构的频道。 本期我们讨论一个开放问题。 为什么流行的开源项目只是凤毛麟角,且很多有名的开源项目都是背靠大公司的。 但是,为什么还有很多个人开发者愿意开源项目呢? 欢迎大家把自己的想法或开源项目发…

大语言模型Prompt工程之使用GPT3.5生成图数据库Cypher

大语言模型Prompt工程之使用GPT3.5生成图数据库Cypher 大语言模型Prompt工程之使用GPT3.5生成图数据库Cypher Here’s the table of contents: 大语言模型Prompt工程之使用GPT3.5生成图数据库Cypher 使用GPT3.5测试了生成Cypher的能力,相比于GPT4生成Cypher的能力&a…

电脑死机的常用排查思路

在工作过程中难免会遇到死机的问题,排查起来并不是那么轻松,下面分享一下我排查死机问题的思路。 判断是软件还是硬件级别的故障 在死机时先尝试移动鼠标,按大小写切换键或数字键盘锁定键,看看光标是否可以移动,大小…

【地铁上的设计模式】--行为型模式:备忘录模式

什么是备忘录模式 备忘录模式(Memento Pattern)是一种行为型设计模式,其目的是在不破坏封装性的前提下,捕获一个对象的内部状态,并在该对象之外保存该状态,以便之后恢复对象到该状态。该模式可以使得对象的…

大数据Doris(十六):分桶Bucket和分区、分桶数量和数据量的建议

文章目录 分桶Bucket和分区、分桶数量和数据量的建议 一、分桶Bucket