训练 Transfomer 模型的内存消耗计算

devtools/2024/9/25 5:27:11/

目录

  • model 内存
  • gradients 内存
  • activates 内存

经典图打底:

训练深度模型的内存消耗主要有以下几个部分:

  1. 存储模型可训练参数
  2. 存储梯度
  3. 存储反向传播中间变量,例如:

L = ( Y − Y ^ ) 2 Y ^ = X T W ∂ L ∂ W = − 2 ( Y − Y ^ ) ∂ Y ^ ∂ W = − 2 ( Y − X T W ) X \begin{aligned} L &= (Y - \hat Y)^2\\ \hat Y &= X^T W\\ \frac{\partial L}{\partial W}&= -2(Y-\hat Y) \frac{\partial \hat Y }{\partial W} = -2(Y- X^T W) X \end{aligned} LY^WL=(YY^)2=XTW=2(YY^)WY^=2(YXTW)X
这里面 X X X 就需要保存下来供反向传播时使用

下面具体的分析中需要用到每一层的具体运算张量,具体可以参考 Transfomer矩阵维度分析及MultiHead详解


model 内存

    """计算储存Transformer模型可训练参数所需的内存参数:- vocab_in_size: vocab_in大小- vocab_out_size: vocab_out大小- encoder_layers_num: 编码器层数- decoder_layers_num: 解码器层数- d_model: 编码器和解码器的隐藏层大小- num_head: 头的数量- embedding_size: 词嵌入大小- filter_size: 前馈子层的隐藏层大小- batch_size: 批大小- seq_len: 输入序列长度- bias: 是否加偏置项- include_pos_embedding: 位置编码是否单独包含可优化参数- dropout_rate: 例如: 0.1- dtype_size: 默认为4 (FP32),若是FP16,改为2返回:- 所需内存,以字节为单位。"""bias = bias * 1# 计算encoder embedding的参数内存消耗encoder_embedding_params = vocab_in_size * embedding_size# 计算 Encoder 的参数内存消耗# Multi-head Attention parameters: 3 * (d_model * d_model) + (d_model * d_model)# Layer normalization: d_model + d_model * bias# Feed-forward network parameters: d_model * filter_size + filter_size * d_modelattention_params = 4 * d_model * d_modellayer_norm_params = d_model + d_model * biasffn_params_params = 2 * d_model * filter_sizeencoder_params = (attention_params + layer_norm_params + ffn_params_params + layer_norm_params) * encoder_layers_num# 计算decoder embedding的参数内存消耗decoder_embedding_params = vocab_out_size * embedding_size# 计算 Decoder 的参数内存消耗# Masked Multi-head Attention parameters: 4 * (d_model * d_model)# Multi-head Attention parameters: 4 * (d_model * d_model)decoder_params = (attention_params + layer_norm_params + attention_params + layer_norm_params + ffn_params_params + layer_norm_params) * decoder_layers_num# 计算最后 output 层的参数内存消耗output_params = d_model * vocab_out_size# 计算储存模型可训练参数所需内存,考虑 dropout_rate(近似估算)model_memory = (encoder_embedding_params + encoder_params + decoder_embedding_params + decoder_params + output_params) * (1 + dropout_rate) * dtype_sizeif include_pos_embedding:model_memory += seq_len * d_model * 2 # encoder 和 decoder 各有一个 pos embedding

gradients 内存

这里除了 gradients 内存,还考虑了一些小项,例如 mask,优化器 等消耗的内存

def get_inputs_mem(batch_size, seq_len, dtype_size=8):"""计算Transformer模型输入数据的内存占用参数:- batch_size: 批大小- seq_len: 输入序列长度- dtype_size: 默认为8 (int64)返回:- 所需内存,以字节为单位。"""return batch_size * seq_len * dtype_size * 2  # 同时计算输入和输出# 计算attention中的mask的内存消耗# Mask: seq_len * seq_len for each attention blockmask_memory = seq_len * seq_len * (encoder_layers_num + decoder_layers_num*2) * dtype_size# 计算gradients消耗的内存, 训练过程中的梯度与模型参数的形状相同,因此梯度的内存大小也是 model_memorygrads_memory = model_memory# 计算优化器消耗的内存,此处以adam为例,对每一个可训练参数,需要储存一个一阶动量和一个二阶动量# 若使用的其他优化器,此处按需修改optimizer_memory = 2 * model_memory# 数据存储消耗的内存inputs_memory = get_inputs_mem(batch_size,seq_len)

activates 内存

    """计算中间结果(activates)的内存消耗,反向传播需要用到这些中间结果参数:- vocab_out_size: vocab_out大小- encoder_layers_num: 编码器层数- decoder_layers_num: 解码器层数- d_model: 编码器和解码器的隐藏层大小- num_head: 头的数量- filter_size: 前馈子层的隐藏层大小- batch_size: 批大小- seq_len: 输入序列长度- dtype_size: 默认为4 (FP32),若是FP16,改为2返回:- 所需内存,以字节为单位。"""# 由于各个layer的输入和输出size都是 batch_size * seq_len * d_model, 先计算出来后续使用N = batch_size * seq_len * d_model * dtype_size# 计算每层 attention 部分的中间结果内存消耗# 1.linear transformation: X*W_q = Q, X*W_k = K, X*W_v = V, 张量为 [batch_size * seq_len * d_model] * [batch_size * d_model * d_model] = [batch_size * seq_len * d_model], 需储存 X (只需储存一个,因为是同一个X)# 2.由于 Attention(Q,K,V) = softmax(QK^T/sqrt(d))V, 其中 QK^T 的张量为 [batch_size * num_head * seq_len * d_model/num_head] * [batch_size * num_head * d_model/num_head * seq_len] = [batch_size * num_head * seq_len * seq_len]# V 张量为 [batch_size * num_head * seq_len * d_model/num_head], 需要存储 Q, K, V, softmax(QK^T/sqrt(d))# 3.output linear transformation: Y = Attention(Q,K,V)*W_2, 张量为 [batch_size * seq_len * d_model] * [batch_size * d_model * d_model] = [batch_size * seq_len * d_model], 需储存 Attention(Q,K,V)linear_memory = Nsoftmax_memory = 3 * N + batch_size * num_head * seq_len * seq_len * dtype_sizeoutput_memory = Nattention_memory = linear_memory + softmax_memory + output_memory# 计算每层的 Layer normalization 的中间结果内存消耗, Layer normalization 输出张量为 batch_size * seq_len * d_modellayer_norm_memory = N# 计算每层的 FFN 部分的中间结果内存消耗# 1.第一层 linear transformation: X*W_1 = Y, 张量为 [batch_size * seq_len * d_model] * [batch_size * d_model * filter_size] = [batch_size * seq_len * filter_size], 需储存 X# 2.中间 Relu 连接: Y' = Relu(Y), 需储存 Y'# 3.第二层 linear transformation: Y'*W_2 = Z, 张量为 [batch_size * seq_len * filter_size] * [batch_size * filter_size * d_model] = [batch_size * seq_len * d_model], 需储存 Y'ffn_memory = N + 2 * batch_size * seq_len * filter_size * dtype_sizeencoder_memory = (attention_memory + layer_norm_memory + ffn_memory + layer_norm_memory) * encoder_layers_numdecoder_memory = (attention_memory + layer_norm_memory + attention_memory + layer_norm_memory + ffn_memory + layer_norm_memory) * decoder_layers_num# 计算 output 层的中间结果内存消耗# 1.output linear transformation: X*W = Y, 张量为 [batch_size * seq_len * d_model] * [batch_size * d_model * vocab_out_size] = [batch_size * seq_len * vocab_out_size], 需储存 X# 2.softmax(Y): 需储存 softmax(Y)output_memory = N + batch_size * seq_len * vocab_out_size * dtype_sizetotal_activates_memory = encoder_memory + decoder_memory + output_memory

将上述三个部分加总,就是训练 Transfomer 模型大概需要的内存消耗。

NOTE:

  1. 这里没有考虑混合精度训练,如果考虑混合精度训练,还需要在不同的部分,使用不同的 dtype_size
  2. 如果是GPT这种 decoder-only 或者 encoder-only 的模型,只需要 decoder_layers_num = 0,即可 (decoder-only 也是这样做的,因为decoder-only 中的 Masked Multi-head Attention 没有了,实际的参数情况和 encoder-only 是一样的)

Reference:
Transformer Memory Arithmetic: Understanding all the Bytes in nanoGPT
Formula to compute approximate memory requirements of transformer models
Transformer Math 101


http://www.ppmy.cn/devtools/94216.html

相关文章

mesh格式转换:glb转ply——使用Blender烘焙贴图到顶点色

1. 导入glb文件 选择shading后,选中物体,就能看到下面的节点树。 2. 创建顶点颜色 这个时候我们可以看到模型的顶点颜色是纯白色的。 2. 将贴图付给材质 原来: 现在: 3. 切换渲染器并烘焙顶点颜色 第三行选择CPU渲染或者GPU…

前端纯数组转树形结构

问题描述 前端需要处理后端返回的数据,展示如下。 解决方式 因为使用ProTable组件,那么数据只要携带children字段,就可以如上图展示。 方式一:后端返回数据的时候,直接封装好,如下: const…

Nginx异常关闭之中了挖矿病毒kswapd0

问题描述:系统突然无法访问了,登录服务器看了一下是因为Nginx服务关闭,重启后过了几天仍然异常关闭 系统:CentOS 7,Nginx 1.20 尝试解决过程:1、查询nginx/logs/error.log、系统日志,都没有查…

Springboot 级联json数据写入Mysql数据库

在Spring Boot应用程序中,将级联JSON数据写入MySQL数据库通常涉及使用JPA(Java Persistence API)和Hibernate进行实体关系映射。以下是一个完整的示例,包括如何定义实体、配置级联关系、处理JSON数据并将其保存到数据库中。 目录…

MySQL 存储引擎之InnoDB

InnoDB 存储引擎是一个非常重要的组件,它提供了许多高级特性,如事务支持、行级锁定、多版本并发控制 (MVCC) 和外键约束等。InnoDB 是 MySQL 5.5 及更高版本中的默认存储引擎,并且被广泛应用于需要高可靠性和并发性的应用程序中。 InnoDB 的…

汽车补光照明实验太阳光模拟器光源

汽车补光照明实验概览 汽车补光照明实验是汽车照明领域的一个重要环节,它涉及到汽车照明系统的性能测试和优化。实验的目的在于确保汽车在各种光照条件下都能提供良好的照明效果,以提高行车安全。实验内容通常包括但不限于灯光的亮度、色温、均匀性、响应…

uniapp3.0实现图片上传公用组件上传uni-file-picker,uni.uploadFile

用uniapp3.0的写法组合式api,setup形式封装一个图片上传公用组件,要求 1、使用uni-file-picker选择文件 2、uni.uploadFile上传图片 3、要能支持上传接口动态化 4、支持删除如片列表中已上传项 5、可以预览已上传列表图片 6、支持动态化限制图片格…

【c++】通过Privilege类来保护数据

简介: 我设计了一个类Privilege类来保护数据,它有效地通过控制访问性和可修改性来保护数据。不过,有几个小地方可以改进或注意,以确保代码的健壮性和易用性。 源码展示: #include <iostream> #include <stdexcept> // 包含 std::runtime_errorclass Privil…