QLoRA 微调Qwen1.5-0.5B-Chat

embedded/2024/9/23 3:37:58/

参考文章:

https://huggingface.co/blog/4bit-transformers-bitsandbytes

https://github.com/artidoro/qlora/tree/main

 本文实战使用QLoRA技术微调阿里的Qwen1.5-0.5B-Chat模型,采用single-gpu 进行训练。

 1. 核心Python包【python版本:3.10.0】

  • torch  2.2.2+cu118
  • accelerate   0.33.0
  • bitsandbytes  0.43.3
  • transformers   4.37.0

2. 使用数据集

https://github.com/DB-lost/self-llm/blob/master/dataset/huanhuan.json

3. 具体实现代码

# coding:utf-8
"""QLoRA Finetune Qwen1.5-0.5B-Chat"""from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM, TrainingArguments, Trainer, BitsAndBytesConfig
from torch.utils.data import Dataset
import torch
from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training
from typing import Dict
import transformers
import json
from transformers.trainer_pt_utils import LabelSmootherIGNORE_TOKEN_ID = LabelSmoother.ignore_indexmax_len = 512
data_json = json.load(open("./data/huanhuan.json", 'r', encoding='utf-8'))
train_json = []
lazy_preprocess = True
gradient_checkpointing = True
TEMPLATE = "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{ '<|im_end|>'}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"def print_model_allarguments_name_dtype(model):for n, v in model.named_parameters():if v.requires_grad:print(f"trainable model arguments:{n}--{v.dtype}--{v.shape}")else:print(f"not trainable model arguments:{n}--{v.dtype}--{v.shape}")config = AutoConfig.from_pretrained("./models/Qwen1.5-0.5B-Chat",trust_remote_code=True)# kv cache 在推理的时候才用,训练时候不用
config.use_cache = Falsetokenizer = AutoTokenizer.from_pretrained("./models/Qwen1.5-0.5B-Chat",model_max_length=max_len,padding_side="right",use_fast=False
)model = AutoModelForCausalLM.from_pretrained("./models/Qwen1.5-0.5B-Chat",torch_dtype=torch.bfloat16,device_map="auto",quantization_config=BitsAndBytesConfig(load_in_4bit=True,bnb_4bit_use_double_quant=True,bnb_4bit_quant_type="nf4",bnb_4bit_compute_dtype=torch.bfloat16,),config=config,low_cpu_mem_usage=True
)print("Original Model: ")
print_model_allarguments_name_dtype(model)model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=gradient_checkpointing)
print("kbit training: ")
print_model_allarguments_name_dtype(model)config = LoraConfig(task_type=TaskType.CAUSAL_LM,target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "gate_proj", "down_proj"],r=64, # Lora 秩lora_alpha=16, # Lora alaph,具体作用参见 Lora 原理lora_dropout=0.05, # Dropout 比例bias='none'
)
model = get_peft_model(model, config)
print("LoRA Model: ")
print_model_allarguments_name_dtype(model)
model.print_trainable_parameters()"""
这个函数调用启用了模型的梯度检查点。
梯度检查点是一种优化技术,可用于减少训练时的内存消耗。
通常,在反向传播期间,模型的中间激活值需要被保留以计算梯度。
启用梯度检查点后,系统只需在需要时计算和保留一部分中间激活值,从而减少内存需求。
这对于处理大型模型或限制内存的环境中的训练任务非常有用。
"""
if gradient_checkpointing:model.enable_input_require_grads()def preprocess(messages,tokenizer: transformers.PreTrainedTokenizer,max_len: int,
) -> Dict:"""Preprocesses the data for supervised fine-tuning."""texts = []for i, msg in enumerate(messages):texts.append(tokenizer.apply_chat_template(msg,chat_template=TEMPLATE,tokenize=True,add_generation_prompt=False,padding=True,max_length=max_len,truncation=True,))input_ids = torch.tensor(texts, dtype=torch.long)target_ids = input_ids.clone()target_ids[target_ids == tokenizer.pad_token_id] = IGNORE_TOKEN_IDattention_mask = input_ids.ne(tokenizer.pad_token_id)return dict(input_ids=input_ids, target_ids=target_ids, attention_mask=attention_mask)class LazySupervisedDataset(Dataset):"""Dataset for supervised fine-tuning."""def __init__(self, raw_data, tokenizer: transformers.PreTrainedTokenizer, max_len: int):super(LazySupervisedDataset, self).__init__()self.tokenizer = tokenizerself.max_len = max_lenself.tokenizer = tokenizerself.raw_data = raw_dataself.cached_data_dict = {}def __len__(self):return len(self.raw_data)def __getitem__(self, i) -> Dict[str, torch.Tensor]:if i in self.cached_data_dict:return self.cached_data_dict[i]ret = preprocess([self.raw_data[i]["conversations"]], self.tokenizer, self.max_len)ret = dict(input_ids=ret["input_ids"][0],labels=ret["target_ids"][0],attention_mask=ret["attention_mask"][0],)self.cached_data_dict[i] = retreturn retclass SupervisedDataset(Dataset):"""Dataset for supervised fine-tuning."""def __init__(self, raw_data, tokenizer: transformers.PreTrainedTokenizer, max_len: int):super(SupervisedDataset, self).__init__()sources = [example["conversations"] for example in raw_data]data_dict = preprocess(sources, tokenizer, max_len)self.input_ids = data_dict["input_ids"]self.labels = data_dict["labels"]self.attention_mask = data_dict["attention_mask"]def __len__(self):return len(self.input_ids)def __getitem__(self, i) -> Dict[str, torch.Tensor]:return dict(input_ids=self.input_ids[i],labels=self.labels[i],attention_mask=self.attention_mask[i],)for i, d in enumerate(data_json):t = {"id": f"identity_{i}","conversations": [{"role": "user","content": d['instruction'] + d['input']},{"role": "assistant","content": d['output']}]}train_json.append(t)dataset_cls = (LazySupervisedDataset if lazy_preprocess else SupervisedDataset
)train_dataset = dataset_cls(train_json, tokenizer=tokenizer, max_len=max_len)
eval_dataset = None
data_module = dict(train_dataset=train_dataset, eval_dataset=eval_dataset)args = TrainingArguments(output_dir="./output/Qwen1.5",per_device_train_batch_size=2,per_device_eval_batch_size=1,gradient_accumulation_steps=8,logging_steps=10,weight_decay=0.01,adam_beta2=0.95,num_train_epochs=5,save_steps=100,learning_rate=3e-4,save_on_each_node=True,gradient_checkpointing=True,lr_scheduler_type='cosine',warmup_ratio=0.01
)
trainer = Trainer(model=model,args=args,**data_module
)
trainer.train()

4. 训练及推理

具体可以参考本人文章:
基于LoRA和AdaLoRA微调Qwen1.5-0.5B-Chat-CSDN博客

5. 具体效果


http://www.ppmy.cn/embedded/91100.html

相关文章

【机器学习】神经网络的无限可能:从基础到前沿

欢迎来到 破晓的历程的 博客 ⛺️不负时光&#xff0c;不负己✈️ 引言 在当今人工智能的浪潮中&#xff0c;神经网络作为其核心驱动力之一&#xff0c;正以前所未有的速度改变着我们的世界。从图像识别到自然语言处理&#xff0c;从自动驾驶到医疗诊断&#xff0c;神经网络的…

算法经典题目:2Sum

题目 有两个非空的链表&#xff0c;每个链表代表一个非负整数。这些数字的位数是以逆序存储的&#xff08;即个位在链表头部&#xff09;&#xff0c;并且每个节点包含一个单独的数字。你的任务是将这两个数字相加&#xff0c;并将结果以同样的链表形式返回。 示例&#xff1…

Axure中文版资源免费下载!

Axure是一种专业的原型设计工具&#xff0c;可以帮助用户以最快的速度将产品想法转化为可视化原型&#xff0c;为设计师、产品经理和开发人员之间的沟通搭建桥梁。Axure功能强大&#xff0c;可绘制高保真原型、建立动态面板、使用复杂函数库、多人合作设计、标准化导出等功能&a…

学生信息管理系统(Python+PySimpleGUI+MySQL)

吐槽一下 经过一段时间学习pymysql的经历&#xff0c;我深刻的体会到了pymysql的不靠谱之处&#xff1b; 就是在使用int型传参&#xff0c;我写的sql语句中格式化%d了之后&#xff0c;我在要传入的数据传递的每一步的去强制转换了&#xff0c;但是他还是会报错&#xff0c;说我…

服务端开发常用知识(持续更新中)

Java方面 1 基础篇 1.1 网络基础 1.1.1 tcp三次握手 TCP协议使用三次握手&#xff08;Three-Way Handshake&#xff09;来建立一个可靠的连接&#xff0c;这是为了确保双方都能同步并且确认连接的有效性。让我们详细解释为什么三次握手是必要的&#xff0c;以及如果只用两次…

【JDK】JDK环境配置踩坑记录Mac

万事胜意哟 首先&#xff0c;确定我们已经下载并安装了JDK8&#xff0c;这里没完成的&#xff0c;可以搜一下下载JDK的步骤 编辑配置环境变量 open -e ~/.bash_profile在打开的配置文件中&#xff0c;添加以下行来设置JAVA_HOME环境变量&#xff0c;并更新PATH变量&#xff1…

C++面试---小米

一、static 关键字的作用&#xff0c;及和const的区别 static关键字作用&#xff1a; 1、在类的成员变量前使用&#xff0c;表示该变量属于类本身&#xff0c;而不是任何类的实例。 2、在类的成员函数前使用&#xff0c;表示该函数不需要对象实例即可调用&#xff0c;且只能访问…

生活需要BGM,悠律凝声环开放式耳机全场景通用

如今&#xff0c;BGM围绕着我们的生活&#xff0c;音乐是生活的调料品&#xff0c;深受运动爱好者的喜爱&#xff0c;不但能够缓解锻炼时的单调&#xff0c;也能够更好地激发我们的身体状态。最近我入手的悠律凝声环ringbuds pro就是这样一款特别适合运动场景使用。 开放式耳机…