什么是torchrun?

embedded/2024/11/13 9:33:56/

torchrun 是 PyTorch 用于分布式训练的命令行工具,旨在简化启动和管理分布式训练任务的过程。下面我将详细讲解 torchrun 的使用方法,并讨论它与分布式数据并行(Distributed Data Parallel, DDP)的区别。

torchrun_2">一、torchrun的使用方法

1. 安装

首先,确保你已经安装了 PyTorch 和 torch.distributed 模块。可以使用以下命令安装 PyTorch:

pip install torch
2. 配置环境

在进行分布式训练之前,你需要配置环境变量。通常需要设置以下环境变量:

  • MASTER_ADDR:主节点的 IP 地址。
  • MASTER_PORT:主节点的端口号。
  • WORLD_SIZE:参与训练的进程总数。
  • RANK:当前进程的排名(从0开始)。

例如,可以在终端中设置这些变量:

export MASTER_ADDR="localhost"
export MASTER_PORT=12355
export WORLD_SIZE=2
export RANK=0
3. 编写训练脚本

编写一个分布式训练脚本。以下是一个简单的示例:

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDPdef setup(rank, world_size):dist.init_process_group("gloo", rank=rank, world_size=world_size)def cleanup():dist.destroy_process_group()def demo_basic(rank, world_size):print(f"Running basic DDP example on rank {rank}.")setup(rank, world_size)# Create model and move it to GPU with id rankmodel = torch.nn.Linear(10, 10).to(rank)ddp_model = DDP(model, device_ids=[rank])# Dummy input and targetinput = torch.randn(20, 10).to(rank)target = torch.randn(20, 10).to(rank)# Loss functionloss_fn = torch.nn.MSELoss()optimizer = torch.optim.SGD(ddp_model.parameters(), lr=0.001)optimizer.zero_grad()outputs = ddp_model(input)loss_fn(outputs, target).backward()optimizer.step()cleanup()def main():world_size = 2mp.spawn(demo_basic, args=(world_size,), nprocs=world_size, join=True)if __name__ == "__main__":main()
torchrun_71">4. 使用torchrun启动训练

可以使用 torchrun 命令启动分布式训练任务。假设你的脚本名为 train.py,可以使用以下命令启动训练:

torchrun --nproc_per_node=2 train.py

torchrunDDP_77">二、torchrun与DDP的区别

torchrun 是一个用于启动和管理分布式训练任务的工具,而 DDP 是 PyTorch 提供的用于实现分布式数据并行的模块。两者的关系可以概括如下:

  • torchrun:负责启动和管理分布式训练进程,简化了环境变量的配置和进程的启动。
  • DDP:负责在分布式训练中进行模型的同步和梯度的聚合。

使用 torchrun 可以简化分布式训练任务的启动,而 DDP 则是在分布式训练中实现具体的并行计算。通常情况下,二者是配合使用的:torchrun 启动训练任务,DDP 进行分布式训练。

示例总结

结合上述内容,我们可以看到,通过使用 torchrun,你可以轻松启动多个分布式训练进程,而无需手动配置复杂的环境变量。这极大地简化了分布式训练的管理和调试过程。而 DDP 则确保了在多个进程之间同步模型参数和梯度,使得训练过程能够高效地并行化。

希望这些信息能帮助你更好地理解和使用 torchrun 以及 PyTorch 的分布式数据并行。若有任何问题或进一步的需求,请随时提问。


http://www.ppmy.cn/embedded/53975.html

相关文章

个人网站搭建-步骤(持续更新)

域名申请 域名备案 域名解析 服务器购买 端口转发 Nginx要在Linux上配置Nginx进行接口转发,您可以按照以下步骤进行操作: 安装Nginx(如果尚未安装): 使用包管理工具(如apt, yum, dnf, 或zypper&#x…

IOS Swift 从入门到精通:从 JSON 文件加载数据

文章目录 常见问题解答数据模型JSON 数据验证 JSON解码 JSON编写 FAQRow 代码添加状态栏背景模糊将内容添加到 FAQView常见问题解答数据模型 此 FAQ 模型符合Decodable,因为我们需要将 JSON 数据解码为 SwiftUI 数据。它还将符合 Identifiable ,因此我们稍后可以在 ForEach …

说一说ABAP CDS View的发展历史与特性

1. 背景 随着SAP Fiori应用程序的兴起,SAP领域的小伙伴接触和使用ABAP CDS View的机会也是越来越多。今天,让我们花些时间,一起在了解下这项技术的设计初衷和发展历史。 2. 设计初衷 说起ABAP CDS View,就不得不提及SAP HANA。…

【LeetCode面试经典150题】117. 填充每个节点的下一个右侧节点指针 II

一、题目 117. 填充每个节点的下一个右侧节点指针 II - 力扣(LeetCode) 给定一个二叉树: struct Node {int val;Node *left;Node *right;Node *next; } 填充它的每个 next 指针,让这个指针指向其下一个右侧节点。如果找不到下一个…

使用 Reqable 在 MuMu 模拟器进行App抓包(https)

1、为什么要抓包? 用开发手机应用时,查看接口数据不能像在浏览器中可以直接通过network查看,只能借助抓包工具来抓包,还有一些线上应用我们也只能通过抓包来排查具体的问题。 2、抓包工具 实现抓包,需要一个抓包工具…

Web3新视野:Lumoz节点的潜力与收益解读

摘要:低估值、高回报、无条件退款80%...... Lumoz正通过其 zkVerifier 节点销售活动,引领一场ZK计算革命。 长期以来,加密市场以其独特的波动性和增长潜力,持续吸引着全球投资者的目光。而历史数据表明,市场往往在一年…

Redis-主从复制-测试主从模式下的读写操作

文章目录 1、在主机6379写入数据2、在从机6380上写数据报错3、从机只能读数据,不能写数据 1、在主机6379写入数据 127.0.0.1:6379> keys * (empty array) 127.0.0.1:6379> set uname jim OK 127.0.0.1:6379> get uname "jim" 127.0.0.1:6379>…

【React】Axios请求头注入token

业务背景: Token作为用户的数据标识,在接口层面起到了接口权限控制的作用,也就是说后端有很多接口都需要通过查看当前请求头信息中是否含有token数据,来决定是否正常返回数据 // 添加请求拦截器 request.interceptors.request.use(config …