Dora原理代码讲解

ops/2024/10/15 22:12:09/

Dora原理及代码讲解

相关的完整训练应用及验证代码在:
https://github.com/mst272/LLM-Dojo 项目的dora部分:Dora部分

前验证知识:Lora

Lora估计大家已经很熟悉了,在这里就不详细介绍Lora的一些原理的,简而言之就是冻结已经训练好的模型权重,在其线性层或其它层中增加一些参数,只训练这些新增的参数层。
可以直接用这一个图来表示:
在这里插入图片描述

直接用代码来解释一下:

创建LoraLayer

python">import torch.nn as nn
import torch# 构建LoraLayer
class LoRALayer(nn.Module):def __init__(self, in_dim, out_dim, rank,  alpha):super().__init__()std_dev = 1 / torch.sqrt(torch.tensor(rank).float())self.A = nn.Parameter(torch.rand(in_dim, rank)*std_dev)self.B = nn.Parameter(torch.zeros(rank, out_dim))self.alpha = alphadef forward(self, x):x = self.alpha * (x @ self.A @ self.B)return x

将Lora合并到线性层

python">class LinearWithLoRA(nn.Module):def __init__(self, linear, rank, alpha):super().__init__()self.linear = linearself.lora = LoRALayer(linear.in_features,linear.out_features,rank,alpha)def forward(self,x):return self.linear(x) + self.lora(x)

Lora的变体—Dora

具体原理可以在代码中理解,实际运用中将LinearWithLoRA替换为LinearWithDoRA即可使用。

python">class LinearWithDoRA(nn.Module):def __init__(self, linear, rank, alpha):super().__init__()self.linear = linearself.lora = LoRALayer(linear.in_features, linear.out_features, rank, alpha)self.m = nn.Parameter(torch.ones(1, linear.out_features))def forward(self, x):linear_out = self.linear(x)lora_out = self.lora(x)lora_out_norm = lora_out / (lora_out.norm(p=2, dim=1, keepdim=True) + 1e-9)dora_modification = self.m * lora_out_normreturn linear_out + dora_modification   

具体如何使用我们可以自己简单的建一个Model进行测试,例如:

python">class TestMLP(nn.Module):def __init__(self, num_features, num_hidden1, num_hidden2, num_class):super().__init__()self.layers = nn.Sequential(nn.Linear(num_features, num_hidden1),nn.ReLU(),nn.Linear(num_hidden1, num_hidden2),nn.ReLU(),nn.Linear(num_hidden2, num_class))def forward(self, x):x = self.layers(x)return x

具体完整的模型训练集验证可见文章开头的代码库,其中有完整的代码演示。


http://www.ppmy.cn/ops/23314.html

相关文章

算法人生(13):从“Scrum”看“PDCA时间管理法”

很多人会好奇为什么“读了很多书,却依然不知道怎么过好这一生”?大家可能都有各自的理解,但正如王阳明先生的“知行合一”所说,“知”要能“行”出来才算“真知”,生活中很多时候知并不一定能行,所以知与行…

Windows 环境部署 ChatGLM2-6b 入门教程

介绍 ChatGLM2-6B是智谱AI及清华KEG实验室发布的中英双语对话模型,它是 ChatGLM-6B 的第二代版本。 主要特点: 性能提升:ChatGLM2-6B 在初代模型的基础上进行了全面升级,使用了 GLM 的混合目标函数,并经过了 1.4T 中…

elementUi中el-date-picker;两个日期选择器第二个必须在第一个之后

<el-row><el-col :span"12"><el-form-item label"实际开始日期" style"margin-top: 10px;" proprealBeginDate><el-date-picker v-model"pmTaskProgressFeedback.realBeginDate" type"date" placehold…

AIGC技术带来的安全与隐私问题探讨

如何看待AIGC技术&#xff1f; 简介&#xff1a;探讨AIGC技术的发展现状和未来趋势。提醒&#xff1a;在发布作品前&#xff0c;请把不需要的内容删掉。 方向一&#xff1a;技术应用 机遇和挑战 AIGC国内场景应用图谱 方向二&#xff1a;伦理与风险 垄断与隐私风险 AI民主化诉…

python 调用 llama

参考&#xff1a; https://blog.51cto.com/u_16175437/9317548 方法一&#xff1a; 要在Python中调用Llama.ai模型来生成回答&#xff0c;你可以使用transformers库&#xff0c;它提供了调用不同的预训练模型的接口。以下是一个简单的例子&#xff0c;展示了如何使用transform…

疯狂的爬虫案例(2)文末附源码

软件版本号&#xff1a; python --version Python 3.8.0 pip show selenium Version: 4.20.0 chromedriver.exe -version 109.0.5414.74 主题&#xff1a;爬取10条动态网页内容&#xff08;电影票房&#xff09; 1.根据xpath获取网页节点&#xff08;CtrlF&#xff09; 2.…

javaWeb项目-医药进出口交易系统功能介绍

项目关键技术 开发工具&#xff1a;IDEA 、Eclipse 编程语言: Java 数据库: MySQL5.7 框架&#xff1a;ssm、Springboot 前端&#xff1a;Vue、ElementUI 关键技术&#xff1a;springboot、SSM、vue、MYSQL、MAVEN 数据库工具&#xff1a;Navicat、SQLyog 1、Java技术 Java是…

Android kotlin 协程异步async与await介绍与使用

一、介绍 在kotlin语言中&#xff0c;协程是一个处理耗时的操作&#xff0c;但是很多人都知道同步和异步&#xff0c;但是不知道该如何正确的使用&#xff0c;如果处理不好&#xff0c;看似异步&#xff0c;其实在runBloacking模块中使用的结果是同步的。 针对如何同步和如何异…