使用猴子补丁对pytorch的分布式接口进行插桩

ops/2024/11/27 20:42:50/

训练脚本:

python">from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from torch import nn
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import DataLoader
import torch.nn.functional as F
import os
import distributed_patch# 设置 NCCL 日志环境变量
'''
os.environ["NCCL_DEBUG"] = "INFO"
os.environ["NCCL_DEBUG_SUBSYS"] = "ALL"  # 或者 COLL
os.environ["NCCL_LOG_FILE"] = "nccl_log.txt"# 运行 PyTorch 分布式代码
'''class Net(nn.Module):  # 模型定义def __init__(self):super(Net, self).__init__()self.flatten = nn.Flatten()self.seq = nn.Sequential(nn.Linear(28 * 28, 128),nn.ReLU(),nn.Linear(128, 64),nn.ReLU(),nn.Linear(64, 10))def forward(self, x):x = self.flatten(x)return self.seq(x)def main():dist.init_process_group(backend='nccl')  # 【集合通讯】其他进程连master,大家互认rank = dist.get_rank()world_size = dist.get_world_size()device_name = f'cuda:{rank}'checkpoint = None  # 各自加载checkpointtry:checkpoint = torch.load('checkpoint.pth', map_location='cpu')  # checkpoint是cuda:0保存的,加载默认会读到cuda:0,所以明确指定给cpuexcept:passmodel = Net().to(device_name)if checkpoint and rank == 0:  # rank0恢复模型参数model.load_state_dict(checkpoint['model'])model = DDP(model)  # 【集合通讯】rank0广播参数给其他进程optimizer = torch.optim.Adam(model.parameters(), lr=0.001)  # model参数一致,则optim会保证其初始状态一致if checkpoint:optimizer.load_state_dict(checkpoint['optimizer'])  # 各自加载checkpointtrain_dataset = MNIST(root='./data', download=True, transform=ToTensor(), train=True)  # 各自加载datasetsampler = DistributedSampler(train_dataset)  # 指派子集给各进程train_dataloader = DataLoader(train_dataset, batch_size=32, sampler=sampler, persistent_workers=True, num_workers=2)val_dataset = MNIST(root='./data', download=True, transform=ToTensor(), train=False)val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=True, persistent_workers=True, num_workers=2)for epoch in range(20):sampler.set_epoch(epoch)  # 【集合通讯】生成随机种子,rank0广播给其他进程model.train()for x, y in train_dataloader:x, y = x.to(device_name), y.to(device_name)pred_y = model(x)  # 【集合通讯】rank0广播model buffer给其他进程loss = F.cross_entropy(pred_y, y)optimizer.zero_grad()loss.backward()  # 【集合通讯】每个参数的梯度做all reduce(每个进程会收到其他进程的梯度,并求平均)optimizer.step()dist.reduce(loss, dst=0)  # 【集合通讯】rank0汇总其他进程的lossif rank == 0:train_avg_loss = loss.item() / world_size# evaluateraw_model = model.moduleval_loss = 0with torch.no_grad():for x, y in val_dataloader:x, y = x.to(device_name), y.to(device_name)pred_y = raw_model(x)loss = F.cross_entropy(pred_y, y)val_loss += loss.item()val_avg_loss = val_loss / len(val_dataloader)print(f'train_loss:{train_avg_loss} val_loss:{val_avg_loss}')# checkpointtorch.save({'model': model.module.state_dict(), 'optimizer': optimizer.state_dict()}, '.checkpoint.pth')os.replace('.checkpoint.pth', 'checkpoint.pth')dist.barrier()  # 【集合通讯】等待rank0跑完evalif __name__ == '__main__':main()# torchrun --nproc_per_node 1 pytorch_dis_gpu.py

插桩脚本:

python">import torch.distributed as dist# 保存原始函数引用
original_functions = {"init_process_group": dist.init_process_group,"all_reduce": dist.all_reduce,"reduce": dist.reduce,"broadcast": dist.broadcast,"barrier": dist.barrier,"get_rank": dist.get_rank,"get_world_size": dist.get_world_size
}# 插桩函数
def patched_init_process_group(*args, **kwargs):print("[distributed] init_process_group called")return original_functions["init_process_group"](*args, **kwargs)def patched_all_reduce(tensor, op=dist.ReduceOp.SUM, group=None, async_op=False):print("[distributed] all_reduce called")return original_functions["all_reduce"](tensor, op, group, async_op)def patched_reduce(tensor, dst, op=dist.ReduceOp.SUM, group=None, async_op=False):print("[distributed] reduce called")return original_functions["reduce"](tensor, dst, op, group, async_op)def patched_broadcast(tensor, src, group=None, async_op=False):print("[distributed] broadcast called")return original_functions["broadcast"](tensor, src, group, async_op)def patched_barrier(*args, **kwargs):print("[distributed] barrier called")return original_functions["barrier"](*args, **kwargs)def patched_get_rank(*args, **kwargs):print("[distributed] get_rank called")return original_functions["get_rank"](*args, **kwargs)def patched_get_world_size(*args, **kwargs):print("[distributed] get_world_size called")return original_functions["get_world_size"](*args, **kwargs)# 替换分布式接口函数为插桩版本
dist.init_process_group = patched_init_process_group
dist.all_reduce = patched_all_reduce
dist.reduce = patched_reduce
dist.broadcast = patched_broadcast
dist.barrier = patched_barrier
dist.get_rank = patched_get_rank
dist.get_world_size = patched_get_world_size


http://www.ppmy.cn/ops/137165.html

相关文章

RL78/G15 Fast Prototyping Board Arduino IDE 平台开发过程

这是一篇基于RL78/G15 Fast Prototyping Board的Arduino IDE开发记录 RL78/G15 Fast Prototyping Board硬件简介(背景)基础测试(方法说明/操作说明)开发环境搭建(方法说明/操作说明代码结果)Arduino IDE RL…

Java使用replaceAll替换时不使用正则表达式

前言 public String replaceAll(String regex, String replacement) {return Pattern.compile(regex).matcher(this).replaceAll(replacement);}在使用String.replaceAll() 方法时,由于入参时regex ,而入参刚好是正则表达式的字符该怎么办?我…

等保测评讲解:安全管理中心

在数字化转型的背景下,网络安全的重要性愈发凸显,而作为中国边疆大省的黑龙江,其网络安全建设更是不可忽视。等保测评,即信息安全等级保护测评,是确保信息系统安全的关键环节。本文将详细讲解黑龙江等保测评中的安全管…

力扣第 66 题 “加一”

题目描述 给定一个由 非负整数组成的非空数组,表示一个整数。在该整数的基础上加一。 最高位数字在数组的首位,数组中每个元素只存储单个数字。 你可以假设除了整数 0 之外,这个整数不会以零开头。 示例 1: 输入: digits [1,2,3] 输出:…

5.算法移植第六篇YOLOV5 /onnx模型转换成rknn

上两篇文章讲述了pytorch模型下best.pt转换成onnx模型,以及将onnx进行简化成为best-sim.onnx, 接下来这篇文章讲述如何将onnx模型转换成rknn模型,转换成该模型是为了在rk3568上运行 1.创建share文件夹 文件夹包含以下文件best-sim.onnx,rknn-tookit2-…

停止在 React 组件回调中使用箭头函数!

在构建 React 应用时,许多开发者都喜欢使用箭头函数,因为它们简洁易用。但你知道吗,在组件回调中直接使用箭头函数可能会导致一些性能问题?在本文中,我们将分析这种情况发生的原因,并探讨你应该考虑的最佳实践。 什么是箭头函数? 在深入讨论最佳实践之前…

从【人工智能】到【计算机视觉】,【深度学习】引领的未来科技创新与变革

前几天偶然发现了一个超棒的人工智能学习网站,内容通俗易懂,讲解风趣幽默,简直让人欲罢不能。忍不住分享给大家,点击这里立刻跳转,开启你的AI学习之旅吧! 前言 – 人工智能教程https://www.captainbed.cn/l…

基于Spring Boot的装饰工程管理系统论文

摘 要 如今社会上各行各业,都喜欢用自己行业的专属软件工作,互联网发展到这个时候,人们已经发现离不开了互联网。新技术的产生,往往能解决一些老技术的弊端问题。因为传统装饰工程项目信息管理难度大,容错率低&#x…