Troubleshooting Loss Inconsistency in Multi-Node PyTorch Training: Operator-Level Consistency Verification

2025/3/16 4:41:32

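Compare the numerical error of every operator between two training runs to pinpoint the PyTorch operators that have consistency problems.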

    • I. Background
    • II. Technical Approach
      • 1. Core Idea
      • 2. Key Technical Points
    • III. Code

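I. Background

The following symptoms were observed in distributed training:

  1. With identical hyperparameters, the Loss curves of repeated training runs differ significantly (fluctuation > 5%).
  2. Convergence stability depends on the number of training nodes: the more nodes, the larger the discrepancy.
  3. A gradient check found no anomalies, which provisionally rules out problems in the model structure.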

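II. Technical Approach

1. Core Idea

Verify numerical consistency at the operator level to locate the native PyTorch operators that make multi-node training results diverge. The key steps are:

  1. Operator interception: use the __torch_dispatch__ mechanism to capture every ATen operator call.
  2. Dual-mode verification:
    • Baseline mode: on the first run, save statistical features of each operator's inputs and outputs.
    • Verification mode: on subsequent runs, check numerical consistency against the baseline in real time.
  3. Divergence localization: when a statistical feature deviates from the baseline, print the full call stack.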
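The interception step can be sketched in isolation. This is a minimal example that assumes PyTorch 2.x running on CPU. The class name `OpRecorder` is hypothetical. It is a TorchDispatchMode subclass that records the ATen overload name and the mean of every floating-point output. The full tool in Section III additionally compares these values against a stored baseline.

```python
import torch
from torch.utils._python_dispatch import TorchDispatchMode


class OpRecorder(TorchDispatchMode):
    """Hypothetical helper: log (op name, output mean) for each ATen call."""

    def __init__(self):
        super().__init__()
        self.records = []

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        out = func(*args, **kwargs)  # run the real ATen kernel
        # The mode is disabled while its own __torch_dispatch__ runs,
        # so the torch calls below are not intercepted again.
        if isinstance(out, torch.Tensor) and out.is_floating_point() and out.numel() > 0:
            self.records.append((str(func), out.float().mean().item()))
        else:
            self.records.append((str(func), None))
        return out


if __name__ == "__main__":
    x = torch.arange(4, dtype=torch.float32)
    with OpRecorder() as rec:
        y = torch.relu(x - 1.0) * 2.0
    for name, mean in rec.records:
        print(name, mean)  # e.g. "aten.relu.default 0.75"
```

In this sketch, `str(func)` yields names in the form `namespace.op.overload`, such as `aten.relu.default`. The same string is what the tool below uses as part of its lookup key.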

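2. Key Technical Points

| Module | Implementation |
|---|---|
| Operator interception | Subclass TorchDispatchMode and override the dispatch logic |
| Feature extraction | Compute the tensor mean (shape, dtype and other non-numerical attributes are not compared) |
| Divergence detection | Tolerance comparison with torch.allclose using atol=1e-4 (torch.allclose's own defaults are atol=1e-08, rtol=1e-05) |
| Result persistence | Serialize the baseline data to disk, one file per rank |
| Blacklist | Skip non-computational operators such as empty_like to reduce false positives |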
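The comparison and persistence rows of the table can be illustrated without a GPU. The sketch below assumes each operator call has already been reduced to a float mean and keyed by a deterministic name such as `aten.mm.default-[output0]-42`. `BaselineStore` and its methods are hypothetical names. The tolerance test reproduces torch.allclose's rule |a - b| <= atol + rtol * |b|, with atol=1e-4 and torch's default rtol=1e-5.

```python
import math
import os
import pickle
import tempfile


class BaselineStore:
    """Hypothetical per-rank store: the first run records, later runs verify."""

    def __init__(self, directory, rank, atol=1e-4, rtol=1e-5):
        self.path = os.path.join(directory, f"rank_{rank}.pkl")
        self.atol, self.rtol = atol, rtol
        self.data = {}
        if os.path.exists(self.path):
            with open(self.path, "rb") as f:
                self.data = pickle.load(f)  # a baseline exists: verification mode

    def check(self, key, mean):
        """Return True if `mean` matches the baseline or is newly recorded."""
        if math.isnan(mean):
            return False  # NaN always counts as a failure and is never stored
        if key not in self.data:
            self.data[key] = mean  # baseline mode: remember the value
            return True
        ref = self.data[key]
        # Same inequality that torch.allclose applies element-wise.
        return abs(mean - ref) <= self.atol + self.rtol * abs(ref)

    def save(self):
        os.makedirs(os.path.dirname(self.path), exist_ok=True)
        with open(self.path, "wb") as f:
            pickle.dump(self.data, f)


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as d:
        first = BaselineStore(d, rank=0)  # run 1: records the baseline
        first.check("aten.mm.default-[output0]-1", 0.25)
        first.save()
        second = BaselineStore(d, rank=0)  # run 2: verifies against it
        print(second.check("aten.mm.default-[output0]-1", 0.25005))  # -> True
        print(second.check("aten.mm.default-[output0]-1", 0.26))     # -> False
```

Comparing a single scalar per tensor keeps the baseline small. The trade-off is that a deviation whose positive and negative parts cancel out in the mean will go undetected.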

III. Code

```python
import inspect
import os
import pickle
import traceback
from datetime import datetime

import torch
import torch.distributed as dist
from torch.utils._python_dispatch import TorchDispatchMode

CACHE_DIR = "/logs"  # baseline files: /logs/rank_<rank>.pkl
ATOL = 1e-4

# The baseline is loaded once per process instead of re-reading and
# rewriting the pickle on every comparison; flush_cache() saves it.
_cache = None         # dict: "<op>-[input|output0|output1]-<n>[-<idx>]" -> mean
_cache_path = None
_cache_dirty = False  # True when new baseline entries were recorded
index_counter = 0


def is_tensor(val):
    # nn.Parameter is a torch.Tensor subclass, so one check covers both.
    return isinstance(val, torch.Tensor)


def get_rank():
    # Fall back to rank 0 when torch.distributed is not initialized.
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank()
    return 0


def load_cache(rank):
    global _cache, _cache_path
    if _cache is None:
        _cache, _cache_path = {}, os.path.join(CACHE_DIR, f"rank_{rank}.pkl")
        if os.path.exists(_cache_path):
            try:
                with open(_cache_path, "rb") as f:
                    _cache = pickle.load(f)
            except Exception:
                print("load failed:", _cache_path)
                traceback.print_exc()
    return _cache


def flush_cache():
    # Persist newly recorded baseline values (only happens on the first run).
    global _cache_dirty
    if _cache is not None and _cache_dirty:
        os.makedirs(os.path.dirname(_cache_path), exist_ok=True)
        with open(_cache_path, "wb") as f:
            pickle.dump(_cache, f)
        _cache_dirty = False


def do_compare(rank, name, tensor):
    global _cache_dirty
    now = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
    if tensor.numel() == 0:
        return True  # the mean of an empty tensor is NaN; nothing to compare
    if torch.isnan(tensor).any():
        print(f"has_nan {now} {rank} {name} {tensor.shape} {tensor.dtype} {tensor.device}")
        return False
    current_mean = torch.mean(tensor.float()).item()
    stored_data = load_cache(rank)
    if name not in stored_data:
        # Baseline mode: remember this operator's mean.
        stored_data[name] = current_mean
        _cache_dirty = True
        return True
    # Verification mode: compare against the stored mean.
    if torch.allclose(torch.tensor(current_mean), torch.tensor(stored_data[name]), atol=ATOL):
        return True
    print("------------------------------In-----------------------------------")
    print(f"MisMatch {now} {rank} {name} {tensor.shape} {tensor.dtype} {tensor.device} "
          f"{current_mean} {stored_data[name]} min:{torch.min(tensor)} max:{torch.max(tensor)}")
    return False


def compare_tensor(name, tensor):
    global index_counter
    # The running counter makes every key unique; matching keys across runs
    # assumes the operator call sequence is identical between runs.
    index_counter += 1
    rank = get_rank()
    if is_tensor(tensor):
        return do_compare(rank, f"{name}-{index_counter}", tensor)
    if isinstance(tensor, (tuple, list)):
        for idx, t in enumerate(tensor):
            if is_tensor(t) and not do_compare(rank, f"{name}-{index_counter}-{idx}", t):
                return False
    return True


def print_stack(tag):
    # Print the Python call stack, outermost frame first, to locate the caller.
    for i, frame_info in enumerate(reversed(inspect.stack())):
        print(f"{i} {frame_info.filename}:{frame_info.lineno}")
    print(f"------------------------------{tag}-----------------------------------")


class TorchDumpDispatchMode(TorchDispatchMode):
    # Substring blacklist of non-computational ops (allocation, views, copies,
    # RNG fills, ...) that would only add noise or false positives.
    BLACK_LIST = ["empty", "like", "zero", "detach", "has", "view", "copy", "arange",
                  "fill", "ones", "lift_fresh", "alias", "scalar_tensor", "clone",
                  "stack", "slice", "source", "barrier", "select", "random",
                  "unsqueeze", "expand", "normal_"]

    def __init__(self, parent):
        super().__init__()
        self.parent = parent

    def is_allow_dump(self, name):
        return not any(item in name for item in self.BLACK_LIST)

    @staticmethod
    def _sync():
        # Make asynchronous CUDA errors surface at the op that caused them.
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        op_name = f"{func}"  # e.g. "aten.mm.default"
        enable_dump = self.is_allow_dump(op_name)
        if enable_dump:
            self._sync()
            if not compare_tensor(f"{op_name}-[input]", args):
                print_stack("Out")
        ret = func(*args, **kwargs)
        if enable_dump:
            self._sync()
            if not compare_tensor(f"{op_name}-[output0]", ret):
                print_stack("Out0")
            # Re-check the inputs after the call to catch in-place updates.
            if not compare_tensor(f"{op_name}-[output1]", args):
                print_stack("Out1")
        return ret


class TorchDebugDumper:
    """Context manager that activates TorchDumpDispatchMode for its body."""

    _CURRENT_Dumper = None

    def __init__(self):
        self._mode = None

    def __enter__(self):
        assert TorchDebugDumper._CURRENT_Dumper is None, "TorchDebugDumper is not re-entrant"
        TorchDebugDumper._CURRENT_Dumper = self
        self._mode = TorchDumpDispatchMode(self)
        self._mode.__enter__()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        TorchDebugDumper._CURRENT_Dumper = None
        if self._mode is not None:
            self._mode.__exit__(exc_type, exc_val, exc_tb)
            self._mode = None
        flush_cache()  # save the baseline recorded during the first run
        return False


def main():
    # pretrain, train_valid_test_datasets_provider, model_provider, forward_step
    # and llama_argument_handler come from the host training script (a
    # Megatron-LM-style entry point); they are not defined in this file.
    pretrain(
        train_valid_test_datasets_provider,
        model_provider,
        forward_step,
        extra_args_provider=llama_argument_handler,
        args_defaults={"tokenizer_type": "GPT2BPETokenizer"},
    )


if __name__ == "__main__":
    # First run: records the baseline. Later runs: report mismatches.
    with TorchDebugDumper():
        main()
```
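The substring test in is_allow_dump filters more than its entries suggest. "select" also matches aten.index_select, which gathers data. "copy" also matches aten._to_copy, which is typically what a dtype cast such as `tensor.to(torch.bfloat16)` reaches at the dispatch level, and a cast can change values. One possible variant matches the operator's base name exactly instead. It parses the name from `f"{func}"`, assuming the `namespace.op.overload` format. The helper names and the list contents below are hypothetical and only illustrative.

```python
# Hypothetical exact-name blacklist; entries are derived from the substring
# list in TorchDumpDispatchMode but are illustrative, not exhaustive.
BLACK_LIST_EXACT = {
    "empty", "empty_like", "empty_strided", "zeros", "zeros_like", "zero_",
    "ones", "ones_like", "fill_", "arange", "detach", "view", "_unsafe_view",
    "alias", "lift_fresh", "lift_fresh_copy", "scalar_tensor", "clone",
    "copy_", "stack", "slice", "select", "unsqueeze", "expand",
    "random_", "normal_", "barrier",
}


def base_op_name(op_str):
    """'aten.index_select.default' -> 'index_select' (assumes 'ns.op.overload')."""
    parts = op_str.split(".")
    return parts[1] if len(parts) >= 3 else op_str


def is_allow_dump_exact(op_str):
    # Only ops whose base name is exactly in the blacklist are skipped.
    return base_op_name(op_str) not in BLACK_LIST_EXACT


if __name__ == "__main__":
    for op in ["aten.index_select.default", "aten._to_copy.default",
               "aten.empty_like.default", "aten.select.int"]:
        print(op, is_allow_dump_exact(op))
```

Exact matching requires listing every overload family explicitly. In exchange, a new or renamed op is compared by default instead of being silently skipped.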

http://www.ppmy.cn/server/175324.html
