参考:https://blog.csdn.net/python122_/article/details/141494273
一个含有1G参数的模型,如果每一个参数都是32bit(4byte),那么直接加载模型就会占用4x1G的显存。
常见的几种精度类型:从一次面试搞懂 FP16、BF16、TF32、FP32
参考:https://zhuanlan.zhihu.com/p/676509123
混合精度训练:
按照训练运行的逻辑来讲:
Step1:优化器会先备份一份FP32精度的模型权重,初始化好FP32精度的一阶和二阶动量(用于更新权重)。
Step2:开辟一块新的存储空间,将FP32精度的模型权重转换为FP16精度的模型权重。
Step3:运行forward和backward,产生的梯度和激活值都用FP16精度存储。
Step4:优化器利用FP16的梯度和FP32精度的一阶和二阶动量去更新备份的FP32的模型权重。
Step5:重复Step2到Step4训练,直到模型收敛。
我们可以看到训练过程中显存主要被用在四个模块上:
模型权重本身(FP32+FP16)
梯度(FP16)
优化器(FP32)
激活值(FP16)
对于llama3.1 8B模型,FP32和BF16混合精度训练,用的是AdamW优化器,请问模型训练时占用显存大概为多少?
解:
模型参数:16(BF16) + 32(PF32)= 48G
梯度参数:16(BF16)= 16G
优化器参数:32(PF32) + 32(PF32)= 64G
不考虑激活值的情况下,总显存大约占用 (48 + 16 + 64) = 128G