Pytorch使用教程(12)-如何进行并行训练?

ops/2025/1/22 8:19:40/

在使用GPU训练大模型时,往往会面临单卡显存不足的情况。这时,通过多卡并行的形式来扩大显存是一个有效的解决方案。PyTorch主要提供了两个类来实现多卡并行:数据并行torch.nn.DataParallel(DP)和模型并行torch.nn.DistributedDataParallel(DDP)。本文将详细介绍这两种方法。

一、数据并行(torch.nn.DataParallel)

  1. 基本原理
    数据并行是一种简单的多GPU并行训练方式。它通过多线程的方式,将输入数据分割成多个部分,每个部分在不同的GPU上并行处理,最后将所有GPU的输出结果汇总,计算损失和梯度,更新模型参数。
    在这里插入图片描述

  2. 使用方法
    使用torch.nn.DataParallel非常简单,只需要一行代码就可以实现。以下是一个示例:

python">import torch
import torch.nn as nn# 检查是否有多个GPU可用
if torch.cuda.device_count() > 1:print("Let's use", torch.cuda.device_count(), "GPUs!")# 将模型转换为DataParallel对象model = nn.DataParallel(model, device_ids=range(torch.cuda.device_count()))
  1. 优缺点
    ‌优点‌:代码简单,易于使用,对小白比较友好。
    ‌缺点‌:GPU会出现负载不均衡的问题,一个GPU可能占用了大部分负载,而其他GPU却负载较轻,导致显存使用不平衡。

二、模型并行(torch.nn.DistributedDataParallel)

  1. 基本原理
    torch.nn.DistributedDataParallel(DDP)是一种真正的多进程并行训练方式。每个进程对应一个独立的训练过程,且只对梯度等少量数据进行信息交换。每个进程包含独立的解释器和GIL(全局解释器锁),因此可以充分利用多GPU的优势,实现更高效的并行训练。
    在这里插入图片描述

  2. 使用方法

    使用torch.nn.DistributedDataParallel需要进行一些额外的配置,包括初始化GPU通信方式、设置随机种子点、使用DistributedSampler分配数据等。以下是一个详细的示例:

初始化环境

python">import torch
import torch.distributed as dist
import argparsedef parse():parser = argparse.ArgumentParser()parser.add_argument('--local_rank', type=int, default=0)args = parser.parse_args()return argsdef main():args = parse()torch.cuda.set_device(args.local_rank)dist.init_process_group('nccl', init_method='env://')device = torch.device(f'cuda:{args.local_rank}')

设置随机种子点

python">import numpy as np# 固定随机种子点
seed = np.random.randint(1, 10000)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)使用DistributedSampler分配数据
python
Copy Code
from torch.utils.data.distributed import DistributedSamplertrain_dataset = ...  # 你的数据集
train_sampler = DistributedSampler(train_dataset, shuffle=True)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=opts.batch_size, sampler=train_sampler
)

初始化模型

python">model = mymodel().to(device)
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank])训练循环
python
Copy Code
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
criterion = nn.CrossEntropyLoss()for ep in range(total_epoch):train_sampler.set_epoch(ep)for inputs, labels in train_loader:inputs, labels = inputs.to(device), labels.to(device)optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()
  1. 优缺点
  • 优点‌:每个进程对应一个独立的训练过程,显存使用更均衡,性能更优。
  • 缺点‌:代码相对复杂,需要进行一些额外的配置。

三、对比与选择

  1. 对比
特点torch.nn.DataParalleltorch.nn.DistributedDataParallel
并行方式多线程多进程
显存使用可能不均衡更均衡
性能一般更优
代码复杂度简单复杂
  1. 选择建议
  • 对于初学者或快速实验,可以选择torch.nn.DataParallel,因为它代码简单,易于使用。
  • 对于需要高效并行训练的场景,建议选择torch.nn.DistributedDataParallel,因为它可以充分利用多GPU的优势,实现更高效的训练。

四、小结

通过本文的介绍,相信读者已经对PyTorch的多GPU并行训练有了更深入的了解。在实际应用中,可以根据模型的复杂性和数据的大小选择合适的并行训练方式,并调整batch size和学习率等参数以优化模型的性能。希望这篇文章能帮助你掌握PyTorch的多GPU并行训练技术。


http://www.ppmy.cn/ops/152144.html

相关文章

基于注解实现去重表消息防止重复消费

基于注解实现去重表消息防止重复消费 1. 背景/问题 在分布式系统中,消息队列(如RocketMQ、Kafka)的 消息重复消费 是常见问题,主要原因包括: 网络抖动:生产者或消费者因网络不稳定触发消息重发。消费者超…

深度学习基础--LSTM学习笔记(李沐《动手学习深度学习》)

前言 LSTM是RNN模型的升级版,神经网络模型较为复杂,这里是学习笔记的记录;LSTM比较复杂,可以先看: 深度学习基础–一文搞懂RNN 深度学习基础–GRU学习笔记(李沐《动手学习深度学习》) RNN:RNN讲解参考&am…

「2024 博客之星」自研Java框架 Sunrays-Framework 使用教程

文章目录 0.序言我的成长历程遇到挫折,陷入低谷重拾信心,迎接未来开源与分享我为何如此看重这次评选最后的心声 1.概述1.主要功能2.相关链接 2.系统要求构建工具框架和语言数据库与缓存消息队列与对象存储 3.快速入门0.配置Maven中央仓库1.打开settings.…

医院管理系统小程序设计与实现(LW+源码+讲解)

专注于大学生项目实战开发,讲解,毕业答疑辅导,欢迎高校老师/同行前辈交流合作✌。 技术范围:SpringBoot、Vue、SSM、HLMT、小程序、Jsp、PHP、Nodejs、Python、爬虫、数据可视化、安卓app、大数据、物联网、机器学习等设计与开发。 主要内容:…

资料03:【TODOS案例】微信小程序开发bilibili

样式 抽象数据类型 页面数据绑定 事件传参

2025美赛数学建模B题思路+模型+代码+论文

2025美赛数学建模A题B题C题D题E题思路模型代码(1.24第一时间更新,更新见文末名片) 论文数学建模感想 纪念逝去的大学数学建模:两次校赛,两次国赛,两次美赛,一次电工杯。从大一下学期组队到现在…

经验收录/用复盘的心态去学习

1.日拱一卒,想法积极。每次解决一点眼前的现实问题,长远来看是最高效的方法,一开始目标太远大,反而增加负担,在能力不够时想得太多反而会不愿意努力。 2.摆脱之前的思路。养成批判性思维,旁观者视角&#…

Linux C\C++方式下的文件I/O编程

【图书推荐】《Linux C与C一线开发实践(第2版)》_linux c与c一线开发实践pdf-CSDN博客 《Linux C与C一线开发实践(第2版)(Linux技术丛书)》(朱文伟,李建英)【摘要 书评 试读】- 京东图书 Lin…