大模型训练框架DeepSpeed使用入门(1): 训练设置

server/2024/9/22 14:44:32/

文章目录

  • 一、安装
  • 二、训练设置
    • Step1 第一步参数解析
    • Step2 初始化后端
    • Step3 训练初始化
  • 三、训练代码展示


官方文档直接抄过来,留个笔记。
https://deepspeed.readthedocs.io/en/latest/initialize.html

使用案例来自:
https://github.com/OvJat/DeepSpeedTutorial


大模型训练的痛点是模型参数过大,动辄上百亿,如果单靠单个GPU来完成训练基本不可能。所以需要多卡或者分布式训练来完成这项工作。

DeepSpeed是由Microsoft提供的分布式训练工具,旨在支持更大规模的模型和提供更多的优化策略和工具。对于更大模型的训练来说,DeepSpeed提供了更多策略,例如:Zero、Offload等。

本文简单介绍下如何使用DeepSpeed


一、安装

pip install deepspeed

二、训练设置

Step1 第一步参数解析

DeepSpeed 使用 argparse 来应用控制台的设置,使用

deepspeed.add_config_arguments()

可以将DeepSpeed内置的参数增加到我们自己的应用参数解析中。

parser = argparse.ArgumentParser(description='My training script.')
parser.add_argument('--local_rank', type=int, default=-1,help='local rank passed from distributed launcher')
# Include DeepSpeed configuration arguments
parser = deepspeed.add_config_arguments(parser)
cmd_args = parser.parse_args()

Step2 初始化后端

与Step3中的 deepspeed.initialize() 不同,
直接调用即可。
一般发生在以下场景

when using model parallelism, pipeline parallelism, or certain data loader scenarios.

在Step3的initialize前,进行调用

deepspeed.init_distributed()

Step3 训练初始化

首先调用 deepspeed.initialize() 进行初始化,是整个调用DeepSpeed训练的入口。
调用后,如果分布式后端没有被初始化后,此时会初始化分布式后端。
使用案例:

model_engine, optimizer, _, _ = deepspeed.initialize(args=cmd_args,model=net,model_parameters=net.parameters(),training_data=ds)

API如下:

def initialize(args=None,model: torch.nn.Module = None,optimizer: Optional[Union[Optimizer, DeepSpeedOptimizerCallable]] = None,model_parameters: Optional[torch.nn.Module] = None,training_data: Optional[torch.utils.data.Dataset] = None,lr_scheduler: Optional[Union[_LRScheduler, DeepSpeedSchedulerCallable]] = None,distributed_port: int = TORCH_DISTRIBUTED_DEFAULT_PORT,mpu=None,dist_init_required: Optional[bool] = None,collate_fn=None,config=None,config_params=None):"""Initialize the DeepSpeed Engine.Arguments:args: an object containing local_rank and deepspeed_config fields.This is optional if `config` is passed.model: Required: nn.module class before apply any wrappersoptimizer: Optional: a user defined Optimizer or Callable that returns an Optimizer object.This overrides any optimizer definition in the DeepSpeed json config.model_parameters: Optional: An iterable of torch.Tensors or dicts.Specifies what Tensors should be optimized.training_data: Optional: Dataset of type torch.utils.data.Datasetlr_scheduler: Optional: Learning Rate Scheduler Object or a Callable that takes an Optimizer and returns a Scheduler object.The scheduler object should define a get_lr(), step(), state_dict(), and load_state_dict() methodsdistributed_port: Optional: Master node (rank 0)'s free port that needs to be used for communication during distributed trainingmpu: Optional: A model parallelism unit object that implementsget_{model,data}_parallel_{rank,group,world_size}()dist_init_required: Optional: None will auto-initialize torch distributed if needed,otherwise the user can force it to be initialized or not via boolean.collate_fn: Optional: Merges a list of samples to form amini-batch of Tensor(s).  Used when using batched loading from amap-style dataset.config: Optional: Instead of requiring args.deepspeed_config you can pass your deepspeed configas an argument instead, as a path or a dictionary.config_params: Optional: Same as `config`, kept for backwards compatibility.Returns:A tuple of ``engine``, ``optimizer``, ``training_dataloader``, ``lr_scheduler``* ``engine``: DeepSpeed runtime engine which wraps the client model for distributed training.* ``optimizer``: Wrapped optimizer if a user defined ``optimizer`` is supplied, or ifoptimizer is specified in json config else ``None``.* ``training_dataloader``: DeepSpeed dataloader if ``training_data`` was supplied,otherwise ``None``.* ``lr_scheduler``: Wrapped lr scheduler if user ``lr_scheduler`` is passed, orif ``lr_scheduler`` specified in JSON configuration. Otherwise ``None``."""

三、训练代码展示

def parse_arguments():import argparseparser = argparse.ArgumentParser(description='deepspeed training script.')parser.add_argument('--local_rank', type=int, default=-1,help='local rank passed from distributed launcher')# Include DeepSpeed configuration argumentsparser = deepspeed.add_config_arguments(parser)args = parser.parse_args()return argsdef train():args = parse_arguments()# init distributeddeepspeed.init_distributed()# init modelmodel = MyClassifier(3, 100, ch_multi=128)# init datasetds = MyDataset((3, 512, 512), 100, sample_count=int(1e6))# init engineengine, optimizer, training_dataloader, lr_scheduler = deepspeed.initialize(args=args,model=model,model_parameters=model.parameters(),training_data=ds,# config=deepspeed_config,)# load checkpointengine.load_checkpoint("./data/checkpoints/MyClassifier/")# trainlast_time = time.time()loss_list = []echo_interval = 10engine.train()for step, (xx, yy) in enumerate(training_dataloader):step += 1xx = xx.to(device=engine.device, dtype=torch.float16)yy = yy.to(device=engine.device, dtype=torch.long).reshape(-1)outputs = engine(xx)loss = tnf.cross_entropy(outputs, yy)engine.backward(loss)engine.step()loss_list.append(loss.detach().cpu().numpy())if step % echo_interval == 0:loss_avg = np.mean(loss_list[-echo_interval:])used_time = time.time() - last_timetime_p_step = used_time / echo_intervalif args.local_rank == 0:logging.info("[Train Step] Step:{:10d}  Loss:{:8.4f} | Time/Batch: {:6.4f}s",step, loss_avg, time_p_step,)last_time = time.time()# save checkpointengine.save_checkpoint("./data/checkpoints/MyClassifier/")

最后~

码字不易~~

独乐不如众乐~~

如有帮助,欢迎点赞+收藏~~



http://www.ppmy.cn/server/41509.html

相关文章

CVHub | CVPR 2024 | 英伟达发布新一代视觉基础模型: AM-RADIO = CLIP + DINOv2 + SAM

本文来源公众号“CVHub”,仅用于学术分享,侵权删,干货满满。 原文链接:CVPR 2024 | 英伟达发布新一代视觉基础模型: AM-RADIO CLIP DINOv2 SAM 标题:《AM-RADIO: Agglomerative Vision Foundation Model Reduce Al…

Charger之二输入电压动态电源原理(VIN-DPM)

主要内容 Charger的VIN-DPM 前篇内容:电池管理IC(Charger)了解一下? 领资料:点下方↓名片关注回复:粉丝群 正文 一、 VIN-DPM概念 VIN-DPM是指输入电压动态电源管理(Input voltage dynamic…

【瑞萨RA6M3】2. UART 实验

https://blog.csdn.net/qq_35181236/article/details/132789258 使用 uart9 配置 打印 void hal_entry(void) {/* TODO: add your own code here */fsp_err_t err;uint8_t c;/* 配置串口 */err g_uart9.p_api->open(g_uart9.p_ctrl, g_uart9.p_cfg);while (1){g_uart9.…

【C++】priority_queues(优先级队列)和反向迭代器适配器的实现

目录 一、 priority_queue1.priority_queue的介绍2.priority_queue的使用2.1、接口使用说明2.2、优先级队列的使用样例 3.priority_queue的底层实现3.1、库里面关于priority_queue的定义3.2、仿函数1.什么是仿函数?2.仿函数样例 3.3、实现优先级队列1. 1.0版本的实现…

用 JavaScript 计算 SHA-256 hash值

SHA-256算法是一个广泛使用的散列函数,它产生256位的hash值。它用于许多安全应用程序和协议,包括 TLS 和 SSL、 SSH、 PGP 和比特币。 在 JavaScript 中计算 SHA-256 hash值使用原生 API 很容易,但是浏览器和 Node.js 之间有一些区别。由于浏…

什么是web3D?应用场景有哪些?如何实现web3D展示?

Web3D是一种将3D技术与网络技术完美结合的全新领域,它可以实现将数字化的3D模型直接在网络浏览器上运行,从而实现在线交互式的浏览和操作。 Web3D通过将多媒体技术、3D技术、信息网络技术、计算机技术等多种技术融合在一起,实现了它在网络上…

Vulstack红队评估(一)

文章目录 一、环境搭建1、网络拓扑2、web服务器(win7)配置3、域控(winserver2008)配置4、域内机器(windows 2003)配置5、调试网络是否通常 二、web渗透1、信息搜集2、端口扫描3、目录扫描4、弱口令5、phpmyadmin getshell日志gets…

强化训练:day9(添加逗号、跳台阶、扑克牌顺子)

文章目录 前言1. 添加逗号1.1 题目描述2.2 解题思路2.3 代码实现 2. 跳台阶2.1 题目描述2.2 解题思路2.3 代码实现 3. 扑克牌顺子3.1 题目描述3.2 解题思路3.3 代码实现 总结 前言 1. 添加逗号   2. 跳台阶   3. 扑克牌顺子 1. 添加逗号 1.1 题目描述 2.2 解题思路 我的写…