从零开始训练一个大语言模型需要多少天?

server/2024/11/13 18:07:03/

一,前言

在AI领域,训练一个大型语言模型(LLM)是一个耗时且复杂的过程。几乎每个做大型语言模型(LLM)训练的人都会被问到:“从零开始,训练大语言模型需要多久和花多少钱?”虽然网上有很多关于训练技巧和模型评估的文章,但很少有直接告诉你如何估算训练时间和成本的。

前面分享了一些关于大模型/本地知识库的安装部署方法,无需编写代码,即可使用Ollama+AnythingLLM搭建企业私有知识库,或者,三步完成Llama3.2在算力魔方的INT4量化和部署...

本篇文章就教你一个简单的方法,帮你快速估算基于大语言模型权重大小、数据量以及可用GPU算力训练大语言模型所需的时间和成本。

二,估算方法

训练模型时,处理数据和更新模型参数需要大量的计算,我们用浮点运算次数(FLOPs)来表示。首先,我们要估算处理一个token所需的FLOPs,包括前向传递和反向传递两个部分。

  • 前向传递:

每个token的前向传递涉及的加乘操作数大约为:

FLOPsforward = 2 x N²+2 x N x Dmodel

这里N表示模型的参数量,Dmodel是模型的维度。系数2来源于矩阵乘法中的累加操作。

  • 反向传递:

大约需要前向传递的两倍计算量,因为要计算权重和激活值的梯度。

FLOPsbackward =(2 x N²+2 x N x Dmodel)x 2

  • 所以,一个token总的计算量大概是前向传递的三倍。因此,每个训练token的浮点运算可以估算为:

FLOPstotal =(2 x N²+2 x N x Dmodel)x 3

三,GPU性能

现在大多数模型都是用GPU来训练的。不同的GPU有不同的性能,比如NVIDIA的H100、A100或V100。每个GPU的性能可以用每秒浮点运算次数(FLOPS)来衡量。不过,实际训练时,由于多GPU之间的通信等因素,实际性能可能达不到理论上的最高值。

GPU Model

Peak FLOPS (FP32)

H100

67 TFLOPS

A100

19.5 TFLOPS

V100

14 TFLOPS

一个重要的概念是模型FLOPS利用率(MFU),它反映了实际计算效率与理论最大值的比例。通常情况下,随着GPU数量的增加,MFU会下降。LLaMA 3的研究者们用16,000个GPU训练模型时,每个GPU的实际效率为380 teraflops,MFU为38%。

四,实际案例

1,l Llama 3 405B 参数模型

LLaMA 3.1(405B参数)是在15.6万亿token的数据集上训练的。训练这样一个规模的模型所需的总FLOPs可以通过以下方式计算:

  • 模型大小 N = 405B

  • 数据集大小 P = 15.6T

模型使用了16,000个H100 GPU进行训练。据了解,平均吞吐量为每个GPU 400 teraflops。这意味着训练基础设施可以提供的总吞吐量为:

TotalThroughput

= 400TFLOPs/GPU×16,000GPUs

= 6.4ExaFLOPs

最后,通过将所需的总FLOPs除以可用吞吐量,并将结果转换为天数(因为我们真正关心的是训练天数),我们可以得到训练时间。

3.8 x 1025FLOPs ÷ 6.4 x1018FLOPs/秒 = 61

2,成本估算

训练模型不仅耗时,还非常昂贵。以LLaMA 3.1为例,如果一个H100 GPU每小时的费用是2美元,那么用16,000个H100训练这个模型的总成本大约为2 x 24 x 61 x 16,000 = 46,848,000美元。

五,总结

训练大型语言模型是一项技术复杂且资金密集的任务。从零开始,把一个LLaMA 3.1(405B参数)的模型在15.6万亿token数据集上训练出来,大约需要花费61天(假设没有训练中断)和46,848,000美元(仅估算GPU租金、数据集制作费用和研发人力成本未计入),你算对了吗?


http://www.ppmy.cn/server/140844.html

相关文章

深度学习-神经网络基础-激活函数与参数初始化(weight, bias)

一. 神经网络介绍 神经网络概念 神经元构建 神经网络 人工神经网络是一种模仿生物神经网络结构和功能的计算模型, 由神经元构成 将神经元串联起来 -> 神经网络 输入层: 数据 输出层: 目标(加权和) 隐藏层: 加权和 激活 全连接 第N层的每个神经元和第N-1层的所有神经元…

Python学习从0到1 day27 第三阶段 Spark ③ 数据计算 Ⅱ

目录 一、Filter方法 功能 语法 代码 总结 filter算子 二、distinct方法 功能 语法 代码 总结 distinct算子 三、SortBy方法 功能 语法 代码 总结 sortBy算子 四、数据计算练习 需求: 解答 总结 去重函数: 过滤函数: 转换函数: 排…

10. java基础知识(下)

文章目录 一、一带而过二、字符串类型String1. 简单了解2. 关于结束符\03. 自动类型转换与强制类型转换 三、API文档与import导包1. API文档2. import导包 四、java中的数组1. 创建2. 遍历3. 补充4. Arrays类① 简单介绍② 练习 五、方法的重载六、规范约束七、内容出处 一、一…

订单分库分表

一、引言 在当今互联网时代,随着电商、金融等行业的快速发展,订单数量呈爆炸式增长。传统的单一数据库存储订单信息的方式面临着巨大的挑战,如数据存储容量有限、查询性能下降、数据备份和恢复困难等。为了解决这些问题,分库分表技…

讨论一个mysql事务问题

最近在阅读一篇关于隔离级别的文章,文章中提到了一种场景,我们下面来分析一下。 文章目录 1、实验环境2、两个实验的语句执行顺序3、关于start transaction和start transaction with consistent snapshot4、实验结果解释4.1、实验14.2、实验24.3、调整实…

torch.full函数介绍

torch.full 是 PyTorch 中用于创建一个具有指定形状、填充值和数据类型的张量的函数。它非常适用于需要初始化特定数值的张量的情况,比如将所有元素填充为一个常量值。 函数定义 torch.full(size, fill_value, *, dtypeNone, layouttorch.strided, deviceNone, re…

多平台编包动态引入依赖的解决方案

最近开发时遇到了这样的需求,A 平台需要引入一个 video.js,B 平台却是不需要的,那么面向 B 平台打包的时候把依赖装进去自然就不大合适。最好的方法是动态引入依赖,根据平台来判断要不要引入 动态引入依赖 很快啊,动…

剑指offer第九天

1.数组中只出现一次的两个数 #include <vector> class Solution { public:vector<int> FindNumsAppearOnce(vector<int>& nums) {// write code hereint v 0;for(int&x:nums)v^x;int cnt 0;while(v){v>>1;cnt;}int a 0,b 0;for(int&x…