7. 探究模型参数与显存的关系以及不同精度造成的影响

news/2024/9/18 6:28:34/ 标签: 深度学习, 模型参数, 机器学习

这篇文章将探讨两个重点:

  • 模型参数与显存(GPU 内存)之间的关系
  • 不同精度的导入方式,以及它们对显存和性能的影响

理解这些概念会让你在模型的选择上更加游刃有余。

文章目录

  • 模型参数与显存的关系
  • 不同精度的导入方式及其影响
    • 常见的数值精度格式
    • 对显存占用的影响
  • 精度的权衡与选择
    • 准确性 vs. 性能
    • 何时选择何种精度
    • 硬件兼容性
  • 实际应用中的精度技巧
    • 使用 FP16 精度
    • 使用 BF16 精度
    • 使用 INT8 量化
    • 消除警告
  • 实践示例
    • 对比不同精度下的显存占用
    • 常见问题及解决方案
      • 问题一:`RuntimeError: Failed to import transformers.models.gpt2.modeling_gpt2 because of the following error`
      • 问题二:`TypeError: dispatch_model() got an unexpected keyword argument 'offload_index'`
  • 总结
  • 参考文献

模型参数与显存的关系

模型参数量与内存占用

神经网络模型由多个层组成,每一层都包含权重(weight)和偏置(bias),这些统称为模型参数。而模型的参数量一般直接影响它的学习和表示能力。

模型大小计算公式
模型大小(字节) = 参数数量 × 每个参数的字节数 \text{模型大小(字节)} = \text{参数数量} \times \text{每个参数的字节数} 模型大小(字节)=参数数量×每个参数的字节数

示例:

对于一个拥有 10 亿(1,000,000,000) 参数的模型,使用 32 位浮点数(float32 表示,每个参数占用 4 字节,即:
模型大小 = 1 , 000 , 000 , 000 × 4 字节 = 4 GB \text{模型大小} = 1,000,000,000 \times 4 \text{字节} = 4 \text{GB} 模型大小=1,000,000,000×4字节=4GB

具体来讲,以meta-llama/Meta-Llama-3.1-70B-Instruct 这个拥有 700 亿(70B) 参数的大模型为例,我们仅考虑模型参数,它的显存需求就已经超过了大多数消费级 GPU(如 RTX 4090 最高 48G):
70 × 1 0 9 × 4 字节 = 280 GB 70 \times 10^9 \times 4 \text{字节} = 280 \text{GB} 70×109×4字节=280GB

GPU 显存需求

而在实际部署模型时,GPU 不仅需要容纳模型参数,还需要处理其他数据,这意味着更大的显存占用量。其中包括:

  • 模型参数:模型的权重和偏置。
  • 优化器状态(仅训练时):如动量(momentum)和梯度平方和等信息。
  • 中间激活值:前向传播和反向传播过程中产生的中间结果。
  • 批量大小(Batch Size):一次处理的数据样本数量。

推理与训练的区别

  • 推理阶段:仅需加载模型参数和少量的中间激活值。
  • 训练阶段:需要额外存储梯度和优化器状态,因此显存需求更大。

不同精度的导入方式及其影响

为了降低显存占用,我们可以使用不同的数值精度格式来存储模型参数,这些精度格式在内存使用和计算性能上各有优劣。

常见的数值精度格式

  • FP32(32 位浮点数):标准精度,每个参数占用 4 字节
  • FP16(16 位浮点数):半精度浮点数,每个参数占用 2 字节
  • BF16(16 位脑浮点数):与 FP16 类似,但具有更大的指数范围,适用于深度学习
  • INT8(8 位整数):低精度整数,每个参数占用 1 字节
  • 量化格式4 位 或更低,用于特殊的量化算法,进一步减少内存占用。

对显存占用的影响

使用更低的精度可以显著减少模型的内存占用:

  • FP16 相对于 FP32:内存占用减半。
  • INT8 相对于 FP32:内存占用减少到原来的四分之一。

示例

对于一个 70B 参数的模型:

  • FP32 精度:280 GB 显存。
  • FP16/BF16 精度:140 GB 显存。
  • INT8 精度:70 GB 显存。

注意:实际显存占用还受到其他因素影响,如 CUDA 上下文、中间激活值和显存碎片等,因此不会严格按照理论值减半或减少四分之一。对于较小的模型,差距可能不会那么显著。

精度的权衡与选择

准确性 vs. 性能

  • 高精度(FP32)

    • 优点:更高的数值稳定性和模型准确性。
    • 缺点:占用更多显存,计算速度较慢
  • 低精度(FP16/INT8)

    • 优点:占用更少的显存,计算速度更快
    • 缺点:可能引入数值误差,影响模型性能。

何时选择何种精度

  • FP32

    • 适用于训练小型模型或对数值精度要求较高的任务。
  • FP16/BF16

    • 适用于训练大型模型,利用混合精度(Mixed Precision)来节省显存并加速计算。
  • INT8

    • 主要用于推理阶段,尤其是在显存资源有限的情况下部署超大模型

硬件兼容性

  • FP16 支持

    • 大多数现代 NVIDIA GPU(如 RTX 20 系列及以上)支持 FP16。
  • BF16 支持

    • 需要 NVIDIA A100、H100 等数据中心级别的 GPU,或最新的 RTX 40 系列 GPU。
  • INT8 支持

    • 需要特殊的库(如 bitsandbytes)和硬件支持。

实际应用中的精度技巧

使用 FP16 精度

在训练中启用混合精度

PyTorch 提供了 torch.cuda.amp 模块,可以方便地实现混合精度训练,加速计算并降低显存占用。

import torch
from torch import nn, optim
from torch.cuda.amp import GradScaler, autocast# MPS (Metal Performance Shaders) for Apple Silicon GPUs
device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"model = nn.Sequential(...)  # 定义模型
model.to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
scaler = GradScaler()for data, labels in dataloader:data = data.to(device)labels = labels.to(device)optimizer.zero_grad()with autocast():outputs = model(data)loss = criterion(outputs, labels)scaler.scale(loss).backward()scaler.step(optimizer)scaler.update()

在推理中使用 FP16

model.half()  # 将模型转换为 FP16
model.to(device)  # 将模型移动到合适的设备
inputs = inputs.half().to('cuda')
outputs = model(inputs)

使用 BF16 精度

BF16(Brain Floating Point)具有与 FP32 相同的指数位数,减小了溢出和下溢的风险。

model = model.to(torch.bfloat16).to(device)
inputs = inputs.to(torch.bfloat16).to(device)
outputs = model(inputs)

注意:并非所有 GPU 都支持 BF16,你需要检查硬件兼容性。

使用 INT8 量化

安装 bitsandbytes

pip install bitsandbytes

使用 bitsandbytes 库实现 INT8 量化

from transformers import AutoModelForCausalLM
import bitsandbytes as bnbmodel_name = 'gpt2-large'model = AutoModelForCausalLM.from_pretrained(model_name,load_in_8bit=True,device_map='auto'
)

消除警告

在加载模型时,可能会遇到以下警告:

The load_in_4bit and load_in_8bit arguments are deprecated and will be removed in the future versions. Please, pass a BitsAndBytesConfig object in quantization_config argument instead.

解决方法:

使用 BitsAndBytesConfig 对象来配置量化参数。

from transformers import AutoModelForCausalLM, BitsAndBytesConfigbnb_config = BitsAndBytesConfig(load_in_8bit=True)model = AutoModelForCausalLM.from_pretrained(model_name,quantization_config=bnb_config,device_map='auto'
)

实践示例

对比不同精度下的显存占用

加载模型并查看显存占用

以下代码示例展示了在不同精度下加载 gpt2-large 模型时的显存占用情况,并进行简单的推理测试。gpt2-large 大约有 812M(8.12 亿)= 0.812B 个参数。

import os
import gc
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
import bitsandbytes as bnbdef load_model_and_measure_memory(precision, model_name, device):if precision == 'fp32':model = AutoModelForCausalLM.from_pretrained(model_name).to(device)elif precision == 'fp16':model = AutoModelForCausalLM.from_pretrained(model_name,torch_dtype=torch.float16,low_cpu_mem_usage=True).to(device)elif precision == 'int8':bnb_config = BitsAndBytesConfig(load_in_8bit=True)model = AutoModelForCausalLM.from_pretrained(model_name,quantization_config=bnb_config,device_map='auto')else:raise ValueError("Unsupported precision")# 确保所有 CUDA 操作完成torch.cuda.synchronize()mem_allocated = torch.cuda.memory_allocated(device) / 1e9print(f"Precision: {precision}, Memory Allocated after loading model: {mem_allocated:.2f} GB")# 删除模型并清理缓存del modelgc.collect()torch.cuda.empty_cache()device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_name = 'gpt2-large'for precision in ['fp32', 'fp16', 'int8']:print(f"\n--- Loading model with precision: {precision} ---")load_model_and_measure_memory(precision, model_name, device)

示例输出

--- Loading model with precision: fp32 ---
Precision: fp32, Memory Allocated after loading model: 3.21 GB--- Loading model with precision: fp16 ---
Precision: fp16, Memory Allocated after loading model: 1.60 GB--- Loading model with precision: int8 ---
Precision: int8, Memory Allocated after loading model: 0.89 GB

额外说明

  • torch.cuda.memory_allocated 仅测量由 PyTorch 当前进程分配的显存,不包括其他进程或系统预留的显存。

常见问题及解决方案

问题一:RuntimeError: Failed to import transformers.models.gpt2.modeling_gpt2 because of the following error

RuntimeError: Failed to import transformers.models.gpt2.modeling_gpt2 because of the following error (look up to see its traceback):
module ‘wandb.proto.wandb_internal_pb2’ has no attribute ‘Result’

解决方法

  • 卸载并重新安装 wandb

    pip uninstall wandb
    pip install wandb
    
  • 如果问题仍然存在,禁用 wandb

    import os
    os.environ["WANDB_DISABLED"] = "true"

问题二:TypeError: dispatch_model() got an unexpected keyword argument 'offload_index'

解决方法:

  • 检查 transformersaccelerate 库的版本:

    import transformers
    import accelerateprint(f"Transformers version: {transformers.__version__}")
    print(f"Accelerate version: {accelerate.__version__}")
    
  • 更新库:

    pip install --upgrade transformers accelerate
    

总结

现在你应该理解了模型参数与显存的关系,以及不同数值精度对显存和性能的影响,这不仅在实际应用中具有重要意义,也是面试中的常见考点,而且对于后续的学习同样很重要。毕竟看得懂代码在说什么,比当作黑箱要好得多。

最后的思考:

精度的降低意味着性能的妥协,在我过去的一些小型试验中,低精度下训练的性能还是一般都不如高精度。但,跑不跑的好是一回事,能不能跑又是另一回事,如果低显存能跑大模型,性能上的妥协也是完全可以接受的。

参考文献

  • PyTorch Mixed Precision Training
  • Transformers Documentation
  • bitsandbytes - Github

http://www.ppmy.cn/news/1525639.html

相关文章

Camera2 预览旋转方向、拍照、录像成像旋转

文章目录 前言一、思考问题二、基础补充、资源参考架构图了解Camera相关专栏零散知识了解部分相机源码参考,学习API使用,梳理流程,偏应用层Camera2 系统相关 三、核心问题:预览方向不对【图片、视频】、成像存储不对、拉伸问题预览…

类型参数传值问题

一、基本数据类型传参问题 public static void main(String[] args) throws Exception {Integer number null;method01(number);}private static void method01(int number){System.out.println("number " number);}Ps: 基于int基本数据类型传参的时候&#xff0c…

Linux操作系统入门(二)

完成了前篇所进行的VMware下载安装,并在其内配置了CentOS7的linux操作系统之后,我们得以正式进入了Linux的世界。 一.安装FinalShell 在本篇中,为了更好的在Windows系统上对虚拟机中的linux操作系统进行操作,我们需要下载一款新…

基于I2S的音频ADC_DAC的_FPGA的驱动

前言 这是博主自己原创的成果,如要转载或者引用,请标明出处,具体的视频讲解见我的bili视频讲解,先附链接 引出目的 课题项目需求做一个基于FPGA的相控扬声器后面进行数字滤波器的设计与实现后期FPGA算法的实现 整体模块框图 驱…

java重点学习-线程池的使用和项目案例

十一 线程池的使用场景 你们项目哪里用到了多线程 批量导入:使用了线程池CountDownLatch批量把数据库中的数据导入到了ES(任意)中,避免OOM数据汇总:调用多个接口来汇总数据,如果所有接口(或部分接口)的没有依赖关系,就可以使用线程池future来…

基于APISIX实现API网关案例分享

一、APISIX介绍 1、定义 Apache APISIX 是一个动态、实时、高性能的云原生 API 网关。它构建于 NGINX + ngx_lua 的技术基础之上,充分利用了 LuaJIT 所提供的强大性能。 2、软件架构 2.1、架构图 APISIX 主要分为两个部分: APISIX 核心:包括 Lua 插件、多语言插件运行时…

Opencv实现提取卡号(数字识别)

直接开始 实行方法 解析命令行参数:使用argparse库来解析命令行输入,确保用户提供了输入图像和模板图像的路径。 读取模板图像:使用cv2.imread()函数读取模板图像的路径,并显示原始图像。 图像预处理: 将图像转换为…

Java面试篇基础部分-Java中的集合类

Java集合是面试中经常被问到的一块内容,很多人在这个地方被面试官吊打。Java集合类被定义在java.util包中,主要有四种集合,分别是List、Queue、Set和Map,每种集合分类如下图所示 List集合 List是一种在开发中比较常用的集合类,作为有序的Collection的典范,分别有如下的…

ubuntu20.4安装Qt5.15.2

ubantu20.4镜像下载地址: https://releases.ubuntu.com/focal/ubuntu-20.04.6-desktop-amd64.iso Qt5.15.2下载地址: https://download.qt.io/official_releases/online_installers/ 安装步骤 1、进入地址后选择对应安装包,我这是ubuntu…

redis基本数据结构-string

文章目录 1. redis的string数据结构2. 常见的业务场景2.1 缓存功能案例讲解背景优势解决方案代码实现 2.2 计数器案例讲解背景优势解决方案代码实现 2.3 分布式锁案例讲解背景优势解决方案代码实现 2.4 限流案例讲解背景优势解决方案代码实现 2.5 共享session案例讲解背景优势解…

HarmonyOS开发之路由跳转

文章目录 一、路由跳转模式与实例1.router.pushUrl2.router.replaceUrl3.router.back 一、路由跳转模式与实例 跳转模式 有点类似于vue的路由跳转 router.pushUrl 保留路由栈,保留当前的页面;router.replaceUrl 销毁当前页面,跳转一个新的页…

Go语言现代web开发08 if和switch分支语句

if语句 If is the most common conditional statement in programming languages. If the result of the condition caculation is positive(true), the code inside if statement will be executed. In the next example, value a will be incremented if it is less than 10…

opencv学习:信用卡卡号识别

该代码用于从信用卡图像中自动识别和提取数字信息。该系统将识别信用卡类型,并输出信用卡上的数字序列。 1.创建命令行参数 数字模板 信用卡 # 创建命令行参数解析器 ap argparse.ArgumentParser() # 添加命令行参数 -i/--image,指定输入图像路径 ap.…

饿了么基于Flink+Paimon+StarRocks的实时湖仓探索

摘要:本文整理自饿了么大数据架构师、Apache Flink Contributor 王沛斌老师在8月3日 Streaming Lakehouse Meetup Online(Paimon x StarRocks,共话实时湖仓架构)上的分享。主要分为以下三个内容: 饿了么实时数仓演进之…

python-游戏自动化(一)(实战-自动刷视频点赞)

前提准备 什么是游戏自动化? 游戏自动化是指通过对游戏的界面结构的解析或界面图像的处理与识别,再模拟人工对软件进行的各种操作,从而实现自动化,达到解放双手,节约时间,提高效率的目标。 在本教程中&am…

房产销售系统开发:SpringBoot技术要点

摘 要 随着科学技术的飞速发展,各行各业都在努力与现代先进技术接轨,通过科技手段提高自身的优势;对于房产销售系统当然也不能排除在外,随着网络技术的不断成熟,带动了房产销售系统,它彻底改变了过去传统的…

RDMA应用场景及效果

GPU Direct 参考:网络架构如何支持超万卡的大规模 AI 训练?| AICon_芯片与网络_InfoQ精选文章 GPU 网络的情况已经发生了很大变化。每个 GPU 都有自己的内部互联,例如 NVIDIA 的 A100 或 H800,它们内部的 NVLink 互联可以达到 6…

【网络安全】空字节绕过:URL回调+XSS+SQL绕WAF

未经许可,不得转载。 文章目录 空字节URL回调XSSSQL空字节 \0,也称为null字节,是一个值为零的特殊字符。在编程中,通常用来表示字符串的结束。攻击者可以利用null字节注入来绕过一些验证或过滤机制。 以下三个漏洞,空字节功不可没。 URL回调 密码重置功能,发起请求后…

如何找到UI5 Tooling-UI5命令

文章目录 UI5 Tooling第一步:首先找找到UI5 的官网如下:第二步:找到get started, 学习UI5 Demo第三步:开发环境--搭建安装UI5 命令行界面Global installation to have the command availableAdditional local install …

学习平台|基于java的移动学习平台系统小程序(源码+数据库+文档)

学习平台|学习平台系统|在线学习平台系统小程序 目录 基于java的移动学习平台系统小程序 一、前言 二、系统设计 三、系统功能设计 四、数据库设计 五、核心代码 六、论文参考 七、最新计算机毕设选题推荐 八、源码获取: 博主介绍:✌️大厂码…