用PyTorch 从零开始构建 BitNet 1.58bit

devtools/2024/9/23 9:35:37/

我们手动实现BitNet的编写,并进行的一系列小实验证实,看看1.58bit 模型是否与全精度的大型语言模型相媲美!

什么是量化以及为什么需要它?

量化是用更少的比特数表示浮点数的过程。当两个数字使用不同的比特数进行量化时,浮点运算的计算成本几乎按照减少的比特数的比例降低(理论上)。这使我们能够提高速度并减少机器学习模型的内存消耗。但这通常会导致信息丢失,从而降低准确性,我们可以通过对量化模型进行更多的微调来一定程度上恢复这种损失。

现有的量化方法与 BitNet 1.58bit 对比

大多数量化算法都需要一个全精度的预训练模型。人们通常会应用后训练量化(PTQ)和量化感知训练(QAT)等技术,以使这些算法有效运行。

PTQ 是一种量化技术,模型在训练完成后进行量化。QAT 是对 PTQ 模型的进一步微调,即在考虑量化的情况下进一步训练模型。

而BitNet 采用了一种截然不同的方法,即从头开始训练模型时就进行量化!

BitNet 的量化算法

上图中,通过取绝对值的平均值的一半(假设 n=2)来计算权重裁剪阈值 γ。然后,权重矩阵 W 被相同的值除,导致新的权重矩阵在原始权重值 ≥ γ 时的值 ≥ 1,原始权重值 ≤ -γ 时的值 ≤ -1。对于 -γ 和 γ 之间的值,它们被映射到 -0.99999… 到 0.9999…

当执行 roundclip 时,

对于原始值 ≥ γ,新值为 1.0,原始值 ≤ -γ,新值为 -1.0,原始值在 -γ 和 γ 之间的新值为 0.0。

理论上,结果值可以用信息编码理论表示为 1.58 位。由于位数不能是分数,我们可以用 2 位来表示。

量化函数在Pytorch中的实现

阈值计算:

 def compute_adjustment_factor(self, input_tensor: torch.Tensor):absmean_weight = torch.mean(torch.abs(input_tensor))adjustment_factor = 1e-4 + absmean_weight * 2 # 1e-4 to avoid zero divison errorreturn adjustment_factor

这里没有把绝对值减半,而是乘以了2。但是实验还是成功了!

RoundClip (1.58~= 2bit)

 def compute_2bit_quantized_tensor(self, input_tensor: torch.Tensor):twobit_matrix = torch.clip(input=torch.round(input_tensor), min=-1, max=1)return twobit_matrixdef compute_1bit_quantized_tensor(self, input_tensor: torch.Tensor):return torch.sign(input_tensor)def compute_quantized_tensor(self, input_tensor: torch.Tensor):if self.quantization_mode == QuantizationMode.two_bit:return self.compute_2bit_quantized_tensor(input_tensor)else:return self.compute_1bit_quantized_tensor(input_tensor)

量化步骤

 weight_adjustment_factor = self.compute_adjustment_factor(self.weight)adjusted_weight = self.weight / weight_adjustment_factorquantized_weight = self.compute_quantized_tensor(adjusted_weight)

线性层操作

 F.linear(weight_adjustment_factor * x, quantized_weight, self.bias)

将调整因子与输入相乘,并将其除以量化权重

如果在将权重传递给线性层函数之前对其进行量化,则对量化矩阵的更新不会通过量化函数(因为大多数更新将在1e-4到1e-2之间,当通过量化步骤反向传播时将变为零)。因为原始的权重矩阵永远不会更新,模型永远不会学习!!

但有一个巧妙的工程技巧可以做到这一点,完整的前向传播是这样的

 def forward(self, x):weight_adjustment_factor = self.compute_adjustment_factor(self.weight)adjusted_weight = self.weight / weight_adjustment_factorif self.training:quantized_weight = (adjusted_weight+ (self.compute_quantized_tensor(adjusted_weight) - adjusted_weight).detach())else:quantized_weight = self.compute_quantized_tensor(adjusted_weight)return F.linear(weight_adjustment_factor * x, quantized_weight, self.bias)

量化权重块的值无论

self.training

是否设置为

True

都是相同的。但是当

self.training

设置为

True

时,计算得到的梯度会被优雅地复制到调整后的权重中。这允许在训练过程中更新调整后的权重,同时也更新原始的权重矩阵。

这是从谷歌 DeepMind 的 VQ VAE PyTorch 实现中借鉴的简单却实用的技巧

自定义Pytorch实现的实验结果

下面的实验选择了一个小型模型和一个相对于小型模型来假设足够大的数据集。此外,为了创建目标模型的量化变体,我简单地使用以下代码块,将

nn.Linear

模块替换为这个自定义实现:

 import copydef create_quantized_copy_of_model(input_model: nn.Module, quantization_mode: QuantizationMode):model_copy = copy.deepcopy(input_model)hash_table = {n: m for n, m in model_copy.named_modules()}for key in list(hash_table.keys()):if isinstance(hash_table[key], nn.Linear):new_module = BitNetLinearLayer(in_features=hash_table[key].in_features,out_features=hash_table[key].out_features,bias=hash_table[key].bias is not None,quantization_mode=quantization_mode,)name_chain = key.split(".")parent_module_attr_name = ".".join(name_chain[:-1])parent_module = hash_table[parent_module_attr_name]setattr(parent_module, name_chain[-1], new_module)for n, m in model_copy.named_modules():assert not isinstance(m, nn.Linear)return model_copy

结果如下:

4层FFN的Mnist结果 :

128维6层VIT版本训练Fashion MNIST的结果

128维8层VIT在 CIFAR100上的结果

我们可以看到,除了第一个实验外,2位和1位版本的模型与全精度的常规版本的模型表现得一样好。在第一个实验中,量化模型可能发生了灾难性遗忘。

这些实验并未使用大型语言模型(LLMs)进行,但足以证明论文关于这样的系统能与全精度模型竞争的说法。

我们的实验与论文的唯一一个区别是,这个实现并没有将量化权重存储在2位矩阵中,计算仍以fp32执行的,要真正看到计算速度的提升,需要为此专门的计算内核,我们目前没有能力编写,所以实现仅验证了论文的潜在的论点。

以上实验的所有代码和模块代码都可以在github repo中找到

https://avoid.overfit.cn/post/131875e588ac4f4aa4f15d2dfa5b46db

作者:Chidhambararajan R


http://www.ppmy.cn/devtools/88834.html

相关文章

Day12--Servlet实现前后端交互(案例:学生信息管理系统登录页面)

(在一个完整的项目架构中,servlet的角色和位置) Servlet、GenericServlet和HttpServlet三者之间的关系是Java Web开发中的一个重要概念,它们共同构成了基于Java的服务器端程序的基础。以下是具体分析: 1. Servlet接口…

多语言海外AEON抢单可连单加额外单源码,java版多语言抢单系统

多语言海外AEON抢单可连单加额外单源码,java版多语言抢单系统。此套是全新开发的java版多语言抢单系统。 后端java,用的若依框架,这套代码前后端是编译后的,测试可以正常使用,语言繁体,英文,日…

Laravel php框架与Yii php 框架的优缺点

Laravel和Yii都是流行的PHP框架,它们各自具有独特的优点和缺点。以下是对这两个框架优缺点的详细分析: Laravel PHP框架的优缺点 优点 1、设计思想先进:Laravel的设计思想非常先进,非常适合应用各种开发模式,如TDD&…

【Java】Java学生成绩管理系统(源码+论文)【独一无二】

👉博__主👈:米码收割机 👉技__能👈:C/Python语言 👉公众号👈:测试开发自动化【获取源码商业合作】 👉荣__誉👈:阿里云博客专家博主、5…

构造函数或者析构函数中调用虚函数会怎样

构造函数或者析构函数中调用虚函数会怎样 析构函数中调用虚函数总结 在C中,构造函数和析构函数中调用虚函数会导致一些行为与预期不符。具体来说: 构造函数中调用虚函数 当在构造函数中调用虚函数时,调用的是当前类的版本,而不是…

【算法设计题】实现以字符串形式输入的简单表达式求值,第2题(C/C++)

目录 第2题 实现以字符串形式输入的简单表达式求值 得分点(必背) 题解 1. 初始化和变量定义 2. 获取第一个数字并存入队列 3. 遍历表达式字符串,处理运算符和数字 4. 初始化 count 并处理加减法运算 代码详解 🌈 嗨&#xf…

hive自动安装脚本

使用该脚本注意事项 安装hive之前确定机子有网络。或者yum 更改为本地源,因为会使用epel仓库下载一个pv的软件使用该脚本前提是自行安装好mysql数据库准备好tomcat软件包,该脚本使用tomcat9.x版本测试过能正常执行安装成功,其他版本没有测试…

新书推荐:《码农职场:IT 人求职就业手册》——照亮你的职业道路

💝💝💝欢迎来到我的博客,很高兴能够在这里和您见面!希望您在这里可以感受到一份轻松愉快的氛围,不仅可以获得有趣的内容和知识,也可以畅所欲言、分享您的想法和见解。 推荐:kwan 的首页,持续学…