LoRA微调大语言模型Bert

LoRA是一种流行的微调大语言模型的手段,这是因为LoRA仅需在预训练模型需要微调的地方添加旁路矩阵。LoRA 的作者们还提供了一个易于使用的库 loralib,它极大地简化了使用 LoRA 微调模型的过程。这个库允许用户轻松地将 LoRA 层添加到现有的模型架构中,而无需深入了解其底层实现细节。这使得 LoRA 成为了一种非常实用的技术,既适合研究者也适合开发人员。下面给出了一个LoRA微调Bert模型的具体例子。
下图给出了一个LoRA微调Bert中自注意力矩阵 W Q W^Q WQ的例子。如图所示,通过冻结矩阵 W Q W^Q WQ,并且添加旁路低秩矩阵 A , B A,B A,B来进行微调。同理,使用LoRA来微调 W K W^K WK也是如此。
image.png
我们给出了通过LoRA来微调Bert模型中自注意力矩阵的具体代码。代码是基于huggingface中Bert开源模型进行改造。Bert开源项目链接如下:
https://huggingface.co/transformers/v4.3.3/_modules/transformers/models/bert/modeling_bert.html

基于LoRA微调的代码如下:
# 环境配置
# pip install loralib
# 或者
# pip install git+https://github.com/microsoft/LoRA
import loralib as loraclass LoraBertSelfAttention(BertSelfAttention):"""继承BertSelfAttention模块对Query,Value用LoRA进行微调参数:- r (int): LoRA秩的大小- config: Bert模型的参数配置"""def __init__(self, r=8, *config):super().__init__(*config)# 获得所有的注意力的头数d = self.all_head_size # 使用LoRA提供的库loralibself.lora_query = lora.Linear(d, d, r)self.lora_value = lora.Linear(d, d, r)def lora_query(self, x):"""对Query矩阵执行Wx + BAx操作"""return self.query(x) + F.linear(x, self.lora_query)def lora_value(self, x):"""对Value矩阵执行Wx + BAx操作"""return self.value(x) + F.linear(x, self.lora_value)def forward(self, hidden_states, *config):"""更新涉及到Query矩阵和Value矩阵的操作"""# 通过LoRA微调Query矩阵mixed_query_layer = self.lora_query(hidden_states)is_cross_attention = encoder_hidden_states is not Noneif is_cross_attention and past_key_value is not None:# reuse k,v, cross_attentionskey_layer = past_key_value[0]value_layer = past_key_value[1]attention_mask = encoder_attention_maskelif is_cross_attention:key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))# 通过LoRA微调Value矩阵value_layer = self.transpose_for_scores(self.lora_value(hidden_states))attention_mask = encoder_attention_maskelif past_key_value is not None:key_layer = self.transpose_for_scores(self.key(hidden_states))# 通过LoRA微调Value矩阵value_layer = self.transpose_for_scores(self.lora_value(hidden_states))key_layer = torch.cat([past_key_value[0], key_layer], dim=2)value_layer = torch.cat([past_key_value[1], value_layer], dim=2)else:key_layer = self.transpose_for_scores(self.key(hidden_states))# 通过LoRA微调Value矩阵value_layer = self.transpose_for_scores(self.lora_value(hidden_states))query_layer = self.transpose_for_scores(mixed_query_layer)if self.is_decoder:past_key_value = (key_layer, value_layer)# Query矩阵与Key矩阵算点积得到注意力分数attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":seq_length = hidden_states.size()[1]position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)distance = position_ids_l - position_ids_rpositional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibilityif self.position_embedding_type == "relative_key":relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)attention_scores = attention_scores + relative_position_scoreselif self.position_embedding_type == "relative_key_query":relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_keyattention_scores = attention_scores / math.sqrt(self.attention_head_size)if attention_mask is not None:attention_scores = attention_scores + attention_maskattention_probs = nn.Softmax(dim=-1)(attention_scores)attention_probs = self.dropout(attention_probs)if head_mask is not None:attention_probs = attention_probs * head_maskcontext_layer = torch.matmul(attention_probs, value_layer)context_layer = context_layer.permute(0, 2, 1, 3).contiguous()new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)context_layer = context_layer.view(*new_context_layer_shape)outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)if self.is_decoder:outputs = outputs + (past_key_value,)return outputsclass LoraBert(nn.Module):def __init__(self, task_type, num_classes=None, dropout_rate=0.1, model_id="bert-base-cased",lora_rank=8, train_biases=True, train_embedding=False, train_layer_norms=True):"""- task_type: 设计任务的类型,如:'glue', 'squad_v1', 'squad_v2'.- num_classes: 分类类别的数量.- model_id: 预训练好的Bert的ID,如:"bert-base-uncased","bert-large-uncased".- lora_rank: LoRA秩的大小.- train_biases, train_embedding, train_layer_norms: 这是参数是否需要训练    """super().__init__()# 1.加载权重self.model_id = model_idself.tokenizer = BertTokenizer.from_pretrained(model_id)self.model = BertForPreTraining.from_pretrained(model_id)self.model_config = self.model.config# 2.添加模块d_model = self.model_config.hidden_sizeself.finetune_head_norm = nn.LayerNorm(d_model)self.finetune_head_dropout = nn.Dropout(dropout_rate)self.finetune_head_classifier = nn.Linear(d_model, num_classes)# 3.通过LoRA微调模型self.replace_multihead_attention()self.freeze_parameters()def replace_self_attention(self, model):"""把预训练模型中的自注意力换成自己定义的LoraBertSelfAttention"""for name, module in model.named_children():if isinstance(module, RobertaSelfAttention):layer = LoraBertSelfAttention(r=self.lora_rank, config=self.model_config)layer.load_state_dict(module.state_dict(), strict=False)setattr(model, name, layer)else:self.replace_self_attention(module)def freeze_parameters(self):"""将除了涉及LoRA微调模块的其他参数进行冻结LoRA微调影响到的模块: the finetune head, bias parameters, embeddings, and layer norms """for name, param in self.model.named_parameters():is_trainable = ("lora_" in name or"finetune_head_" in name or(self.train_biases and "bias" in name) or(self.train_embeddings and "embeddings" in name) or(self.train_layer_norms and "LayerNorm" in name))param.requires_grad = is_trainablepeft库中包含了LoRA在内的许多大模型高效微调方法,并且与transformer库兼容。使用peft库对大模型flan-T5-xxl进行LoRA微调的代码例子如下:
# 通过LoRA微调flan-T5-xxl
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from peft import LoraConfig, get_peft_model, prepare_model_for_int8_training, TaskType
# 模型介绍:https://huggingface.co/google/flan-t5-xxl
model_name_or_path = "google/flan-t5-xxl"model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path, load_in_8bit=True, device_map="auto")
peft_config = LoraConfig(r=8,lora_alpha=16, target_modules=["q", "v"], # 仅对Query,Value矩阵进行微调lora_dropout=0.1,bias="none", task_type=TaskType.SEQ_2_SEQ_LM
)
model = get_peft_model(model, peft_config)
# 打印可训练的参数
model.print_trainable_parameters()

http://www.ppmy.cn/server/101544.html

相关文章

Docker 部署 XXL-JOB

Docker 部署 XXL-JOB 目录 引言环境准备创建 MySQL 用户并授予权限使用 Docker 部署 XXL-JOB配置 XXL-JOB验证部署总结 1. 引言 XXL-JOB 是一个开源的分布式任务调度平台,旨在简化定时任务的管理和调度操作。其强大的功能和灵活性,使其在互联网公司和…

【python制作一个小程序作为七夕礼物】

制作一个七夕节礼物的小程序,我们可以考虑一个简单的互动程序,比如一个“七夕情侣姓名配对指数计算器”。这个程序将接收两个名字作为输入,然后输出一个随机的“配对指数”和一些浪漫的话语。以下是一个使用Python实现的简单示例:…

无字母绕过webshell

目录 代码 payload构造 php7 php5 构造payload 代码 不可以使用大小写字母、数字和$然后实现eval的注入执行 <?php if(isset($_GET[code])){$code $_GET[code];if(strlen($code)>35){die("Long.");}if(preg_match("/[A-Za-z0-9_$]/",$code))…

学习C语言第十五天

第一项 C 字符串 字符串实际上是使用空字符 \0 结尾的一维字符数组&#xff0c;\0 是用于标记字符串的结束。 空字符&#xff08;Null character&#xff09;又称结束符&#xff0c;缩写 NUL&#xff0c;是一个数值为 0 的控制字符&#xff0c;\0 是转义字符&#xff0c;意思…

AI产品经理需要了解的算法知识

这篇文章给大家系统总结一下AI产品经理需要了解的算法知识。 1、自然语言生成&#xff08;NLG&#xff09; 自然语言生成&#xff08;Natural Language Generation&#xff0c;简称NLG&#xff09;是一种人工智能技术&#xff0c;它的目标是将计算机的数据、逻辑或算法产生的…

Nuxt3【服务器】server 详解

server 文件夹中的内容&#xff0c;会被自动注册为API和服务器处理程序。 服务器 API 对应路径 server/api server/api/hello.ts export default defineEventHandler((event) > {return {hello: world} })页面中使用 <script setup lang"ts"> const { da…

信号入门学习

1 信号入门 1.1 生活中的信号 音信号&#xff1a;如电话铃声、门铃、汽车喇叭声等。视觉信号&#xff1a;如交通信号灯、手语、交通标志、指示牌等。触觉信号&#xff1a;如触摸、压力变化、温度变化等。嗅觉信号&#xff1a;如花香、食物的香味、某些化学物质的气味等。味觉…

Python爬虫使用实例

IDE&#xff1a;大部分是在PyCharm上面写的 解释器装的多 → 环境错乱 → error&#xff1a;没有配置&#xff0c;no model 爬虫可以做什么&#xff1f; 下载数据【文本/二进制数据&#xff08;视频、音频、图片&#xff09;】、自动化脚本【自动抢票、答题、采数据、评论、点…

微信小程序免费《短视频去水印》

分享一个uniapp开发的微信小程序免费《短视频去水印》小程序 <template><view class"content"><view class"area-wrap"><textarea name"" v-model"state.content" maxlength"800" id"" cols…

微服务相关复习

目录 Spring Cloud 5大组件服务注册发现&#xff1b;nacos与eureka的区别负载均衡&#xff1b;Ribbon负载均衡策略&#xff1b;自定义负载均衡策略 服务雪崩、降级、熔断微服务是怎么监控的&#xff1f;有没有做过限流&#xff1f;怎么做CAP理论&#xff1b;BASE理论采用哪种分…

维基知识库系统Wiki.js本地Linux环境部署并配置公网地址远程访问

&#x1f49d;&#x1f49d;&#x1f49d;欢迎来到我的博客&#xff0c;很高兴能够在这里和您见面&#xff01;希望您在这里可以感受到一份轻松愉快的氛围&#xff0c;不仅可以获得有趣的内容和知识&#xff0c;也可以畅所欲言、分享您的想法和见解。 推荐:kwan 的首页,持续学…

【python学习】AC自动机 高效敏感词过滤与文本匹配:全面掌握pyahocorasick库 (NLP自然语言处理项目实战)

1. 引言 pyahocorasick是一个强大的Python库&#xff0c;用于实现Aho-Corasick算法&#xff0c;以便高效地进行多模式字符串匹配。在处理大规模文本或需要同时查找多个模式的应用场景中&#xff0c;pyahocorasick表现得尤为出色。本篇文章将不仅仅介绍该库的基础用法&#xff…

OBS混音器(Mixers)的重要性和配置指南

在进行直播或录制时,音频管理是非常关键的一环,特别是在需要同时处理多个音频源的复杂设置中。OBS Studio提供了强大的音频管理工具,其中“混音器”功能扮演了核心角色。混音器(Mixers)在OBS中用于控制不同音频源的输出路由,允许用户精确控制哪些音源出现在最终的直播或录…

2024年8月16日(运维自动化 ansible)

一、回顾 1、mysql和python (1)mysql5.7 1.1不需要执行mysql_ssl_rsa_setup 1.2change_master_to 不需要get public key (2)可以使用pymysql非交互的管理mysql 2.1pymysql.connect(host,user,password,database,port) 2.2 cursorconn.cursor() 2.3 cursor.execute("creat…

家里养有宠物浮毛多、异味大,宠物空气净化器有用吗

我家收养了12只流浪猫&#xff0c;掉毛量是很多人想象不到的&#xff0c;对于猫掉毛和人掉头发一个道理&#xff0c;情绪压力&#xff0c;长期熬夜&#xff0c;营养不良&#xff0c;年龄原因都会掉毛或掉头发&#xff0c;猫更是如此&#xff01;但确实之前也不知道一只猫的掉毛…

安全自动化和编排:如何使用自动化工具和编排技术来提高安全操作效率。(第一篇)

深入安全自动化与编排在DevSecOps中的应用&#xff08;第一篇&#xff09; 1. 引言 随着企业采用DevOps方法论&#xff0c;安全自动化与编排技术在提升安全操作效率方面变得至关重要。本文将深入探讨如何在DevSecOps流程中实施安全自动化与编排&#xff0c;结合具体的工具、技…

解决Vue2移动端(H5)项目,手机打开项目侧滑或者按物理返回键,始终是走this.$router.go(-1)

一、原因前言 最近开发Vue2移动端&#xff08;H5&#xff09;项目&#xff0c;用手机打开项目侧滑或者按物理返回键&#xff0c;始终是走this.$router.go(-1)&#xff0c;即相当于点击了浏览器的返回键的项目。目前想要的效果是&#xff1a;只要回到初始页面&#xff0c;点击返…

学懂C++(二十三):高级教程——深入详解C++ 标准库的多线程支持

目录 1. 创建、管理和操作线程&#xff1a;std::thread 2. 互斥量&#xff08;Mutex&#xff09; 3. 锁&#xff08;Lock&#xff09; 4. 条件变量&#xff08;Condition Variables&#xff09; 5. 原子操作&#xff08;Atomic Operations&#xff09; 6. 异步任务和 Fut…

RPA在政务领域的发展前景

随着信息技术的迅猛发展&#xff0c;政务领域也在不断探索创新&#xff0c;以提升政府服务的质量和效率。RPA作为一种自动化技术&#xff0c;打破了传统政务服务人工操作的局限&#xff0c;协助基层人员更高效准确地完成录入、审查、校对和数据汇总等各项繁琐的工作&#xff0c…

ISO 14229 1~7 pdf 标准下载见下面下载链接

https://download.csdn.net/download/xiaofei558008/89638262https://download.csdn.net/download/xiaofei558008/89638262