Running out of memory? A way to save half the memory when loading large models

news/2024/10/21 7:38:19/

Loading huge PyTorch models with linear memory consumption

This article describes a way to save close to half the memory when loading the weights of a huge model.

First, create a model:

import torch
from torch import nn


class BoringModel(nn.Sequential):
    def __init__(self):
        super().__init__()
        self.in_proj = nn.Linear(2, 10)
        self.stages = nn.Sequential(
            nn.Linear(10, 10),
            nn.Linear(10, 10),
        )
        self.out_proj = nn.Linear(10, 2)

Instantiating the model takes 1x memory, where x is the size of the model:

model = BoringModel()
# model is now in memory
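
To make "x" concrete, the parameter bytes of a model can be summed directly. This is a minimal sketch: `parameter_bytes` is a hypothetical helper name, not something from the original post, and the model is rebuilt with the same layer sizes as BoringModel so the snippet runs on its own. It assumes PyTorch's default float32 parameters (4 bytes each).

```python
import torch
from torch import nn


def parameter_bytes(module: nn.Module) -> int:
    """Return the number of bytes occupied by the module's parameters."""
    return sum(p.numel() * p.element_size() for p in module.parameters())


# Same layer sizes as BoringModel, rebuilt here so the snippet is self-contained.
model = nn.Sequential(
    nn.Linear(2, 10),
    nn.Linear(10, 10),
    nn.Linear(10, 10),
    nn.Linear(10, 2),
)

num_params = sum(p.numel() for p in model.parameters())
size_bytes = parameter_bytes(model)
print(num_params, size_bytes)
```

For this toy model x is tiny, but the same count applied to a multi-billion-parameter model gives the figure that has to fit in RAM.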

At some point we save the model to local disk:

torch.save(model.state_dict(), "./checkpoint.pt")
# our model is now stored on disk

Later, we need to load the saved model back (2x memory consumption):

# we need to redefine the model
model = BoringModel()
# 1x memory used
state_dict = torch.load("./checkpoint.pt")
# 2x memory used -> both model and state_dict are in memory!!!
model.load_state_dict(state_dict)
# 1x memory used
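
The 2x peak can be checked by counting bytes while both objects are alive. The sketch below is an illustration under assumptions: `tensor_bytes` is a hypothetical helper, a single `nn.Linear` stands in for the model, and an in-memory buffer replaces `./checkpoint.pt` so nothing is written to disk.

```python
import io

import torch
from torch import nn


def tensor_bytes(tensors) -> int:
    """Total bytes held by an iterable of tensors."""
    return sum(t.numel() * t.element_size() for t in tensors)


model = nn.Linear(10, 10)

# Save to an in-memory buffer instead of ./checkpoint.pt (assumption for this sketch).
buffer = io.BytesIO()
torch.save(model.state_dict(), buffer)
buffer.seek(0)

# A freshly built model and the loaded state_dict are alive at the same time.
new_model = nn.Linear(10, 10)
state_dict = torch.load(buffer)

model_bytes = tensor_bytes(new_model.parameters())
state_dict_bytes = tensor_bytes(state_dict.values())

# The loaded tensors own their storage; they do not alias the model's parameters.
shares_storage = new_model.weight.data_ptr() == state_dict["weight"].data_ptr()

print(model_bytes, state_dict_bytes, shares_storage)
```

Because the two sets of tensors do not share storage, the peak is the sum of both until `state_dict` is released.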

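We need twice the memory to load the weights we stored earlier.
This is a problem if we have a huge model, since we need twice as much free RAM. For example, suppose we have 16GB of RAM and our model uses 10GB. Loading it requires 20GB, so we need to change our strategy.
Recently, PyTorch introduced the meta device. When you put a tensor on the meta device, only its metadata (e.g. shape) is stored, and its values are tossed away. Thus, no space is used for the data.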

A meta example:

x = torch.tensor([1])
x

tensor([1])

x.to(torch.device("meta"))

tensor(..., device='meta', size=(1,), dtype=torch.int64)
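
As a further illustration (a sketch, not part of the original post), a meta tensor can describe a shape far larger than available RAM, and a whole module can be moved to the meta device in the same way. The sizes below are arbitrary and chosen only to make the point.

```python
import torch
from torch import nn

# A tensor that would need roughly 4 TB as float32 if it were materialized.
huge = torch.empty((1_000_000, 1_000_000), device="meta")
print(huge.shape, huge.is_meta)

# Modules can be moved to the meta device too; parameters keep shape and dtype only.
layer = nn.Linear(10, 10).to(torch.device("meta"))
print(layer.weight.shape, layer.weight.dtype, layer.weight.is_meta)
```

Nothing is allocated for the values in either case, which is what makes the meta device useful as a placeholder for a model whose weights will arrive later.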

Therefore, we can use this to load our model with only 1x memory consumption:

  • Define our model: 1x memory

  • Place it on the meta device: 1x memory

  • Load the state_dict: 1x memory

  • Replace all empty parameters of our model with the values inside the state_dict: 1x memory

We first need to figure out how to replace all of the model's parameters with the original ones from the loaded state_dict.

Let’s create the load_state_dict_with_low_memory function.

from typing import Dict


def load_state_dict_with_low_memory(model: nn.Module, state_dict: Dict[str, torch.Tensor]):
    # free up memory by placing the model on the `meta` device
    model.to(torch.device("meta"))
    # we need to associate each key in state_dict to a submodule
    # then, iteratively, re-create all submodules' parameters with the values in `state_dict`
    pass


load_state_dict_with_low_memory(model, {})
model.state_dict()

OrderedDict([('in_proj.weight', tensor(..., device='meta', size=(10, 2))),
             ('in_proj.bias', tensor(..., device='meta', size=(10,))),
             ('stages.0.weight', tensor(..., device='meta', size=(10, 10))),
             ('stages.0.bias', tensor(..., device='meta', size=(10,))),
             ('stages.1.weight', tensor(..., device='meta', size=(10, 10))),
             ('stages.1.bias', tensor(..., device='meta', size=(10,))),
             ('out_proj.weight', tensor(..., device='meta', size=(2, 10))),
             ('out_proj.bias', tensor(..., device='meta', size=(2,)))])

The model is now empty.

Now we have to figure out in which submodule of model each parameter from state_dict has to go. One way to do it is to create a dictionary with [key_in_state_dict] -> [submodule_in_module].

That way we know where to place the values from the loaded state_dict. Remember, as soon as the model is placed on the meta device, all its weights are tossed away.

from typing import Dict


def get_keys_to_submodule(model: nn.Module) -> Dict[str, nn.Module]:
    keys_to_submodule = {}
    # iterate all submodules
    for submodule_name, submodule in model.named_modules():
        # iterate all parameters in each submodule
        for param_name, param in submodule.named_parameters():
            # param_name is organized as <name>.<subname>.<subsubname> ...
            # the deeper we go in the model, the fewer "subname"s we have
            splitted_param_name = param_name.split('.')
            # if we have only one subname, then it means that we reached a "leaf" submodule,
            # we cannot go inside it anymore. This is the actual parameter
            is_leaf_param = len(splitted_param_name) == 1
            if is_leaf_param:
                # we recreate the correct key
                key = f"{submodule_name}.{param_name}"
                # we associate this key with this submodule
                keys_to_submodule[key] = submodule
    return keys_to_submodule


get_keys_to_submodule(model)
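
The same mapping can also be built by splitting each parameter key and looking the owner up with `nn.Module.get_submodule`. This is an alternative sketch, not the author's function: `keys_to_submodule_via_lookup` and `TinyModel` are hypothetical names introduced here, and `TinyModel` simply mirrors BoringModel's layers so the snippet runs standalone.

```python
from typing import Dict

import torch
from torch import nn


class TinyModel(nn.Module):
    """Stand-in with the same layout as BoringModel (assumption for this sketch)."""

    def __init__(self):
        super().__init__()
        self.in_proj = nn.Linear(2, 10)
        self.stages = nn.Sequential(nn.Linear(10, 10), nn.Linear(10, 10))
        self.out_proj = nn.Linear(10, 2)


def keys_to_submodule_via_lookup(model: nn.Module) -> Dict[str, nn.Module]:
    """Map each parameter key to the submodule that directly owns it."""
    mapping = {}
    for key, _ in model.named_parameters():
        # "stages.0.weight" -> module path "stages.0", parameter name "weight"
        module_path, _, _param_name = key.rpartition(".")
        # get_submodule("") returns the model itself, covering top-level parameters
        mapping[key] = model.get_submodule(module_path)
    return mapping


model = TinyModel()
mapping = keys_to_submodule_via_lookup(model)
for key, submodule in mapping.items():
    print(key, "->", type(submodule).__name__)
```

Each key such as `stages.0.weight` ends up pointing at the `Linear` layer that holds it, which is the association the loading function needs.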

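Now we have a way to know which key corresponds to which submodule of model. Let's go back to our load_state_dict_with_low_memory function and materialize each submodule's parameters using the correct values from state_dict.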

def load_state_dict_with_low_memory(model: nn.Module, state_dict: Dict[str, torch.Tensor]):
    # free up memory by placing the model on the `meta` device
    model.to(torch.device("meta"))
    keys_to_submodule = get_keys_to_submodule(model)
    for key, submodule in keys_to_submodule.items():
        # get the value from the state_dict
        val = state_dict[key]
        # we need to substitute the parameter inside submodule,
        # remember key is composed of <name>.<subname>.<subsubname>
        # the actual submodule's parameter is stored inside the
        # last subname. If key is `in_proj.weight`, the correct field is `weight`
        param_name = key.split('.')[-1]
        param_dtype = getattr(submodule, param_name).dtype
        val = val.to(param_dtype)
        # create a new parameter
        new_val = torch.nn.Parameter(val, requires_grad=False)
        setattr(submodule, param_name, new_val)


model.state_dict()


load_state_dict_with_low_memory(model, torch.load("checkpoint.pt"))
model.state_dict()

🎉 We have successfully loaded our checkpoint inside our model with linear memory consumption!


