pytest学习-pytorch单元测试

ops/2024/9/23 2:29:54/

pytorch单元测试

  • 一.公共模块[common.py]
  • 二.普通算子测试[test_clone.py]
  • 三.集合通信测试[test_ccl.py]
  • 四.测试命令
  • 五.测试报告

希望测试pytorch各种算子、block、网络等在不同硬件平台,不同软件版本下的计算误差、耗时、内存占用等指标.

本文基于torch.testing._internal

一.公共模块[common.py]

import torch
from torch import nn
import math
import torch.nn.functional as F
import time
import os
import socket
import sys
from datetime import datetime
import numpy as np
import collections
import math
import json
import copy
import traceback
import subprocess
import unittest
import torch
import inspect
from torch.testing._internal.common_utils import TestCase, run_tests,parametrize,instantiate_parametrized_tests
from torch.testing._internal.common_distributed import MultiProcessTestCase
import torch.distributed as distos.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29500"
os.environ["RANDOM_SEED"] = "0" device="cpu"
device_type="cpu"
device_name="cpu"try:if torch.cuda.is_available():     device_name=torch.cuda.get_device_name().replace(" ","")device="cuda:0"device_type="cuda"ccl_backend='nccl'
except:passhost_name=socket.gethostname()    
sdk_version=os.getenv("SDK_VERSION","")   						 #从环境变量中获取sdk版本号
metric_data_root=os.getenv("TORCH_UT_METRICS_DATA","./ut_data")  #日志存放的目录
device_count=torch.cuda.device_count()if not os.path.exists(metric_data_root):os.makedirs(metric_data_root)def device_warmup(device):'''设备warmup,确保设备已经正常工作,排除设备初始化的耗时'''left = torch.rand([128,512], dtype = torch.float16).to(device)right = torch.rand([512,128], dtype = torch.float16).to(device)out=torch.matmul(left,right)torch.cuda.synchronize()torch.manual_seed(1) 
np.random.seed(1)def loop_decorator(loops,rank=0):'''循环装饰器,用于统计函数的执行时间,内存占用等'''def decorator(func):def wrapper(*args,**kwargs):latency=[]memory_allocated_t0=torch.cuda.memory_allocated(rank)for _ in range(loops):input_copy=[x.clone() for x in args]beg= datetime.now().timestamp() * 1e6pred= func(*input_copy)gt=kwargs["golden"]torch.cuda.synchronize()end=datetime.now().timestamp() * 1e6mse = torch.mean(torch.pow(pred.cpu().float()- gt.cpu().float(), 2)).item()latency.append(end-beg)memory_allocated_t1=torch.cuda.memory_allocated(rank)avg_latency=np.mean(latency[len(latency)//2:]).round(3)first_latency=latency[0]return { "first_latency":first_latency,"avg_latency":avg_latency,"memory_allocated":memory_allocated_t1-memory_allocated_t0,"mse":mse}return wrapperreturn decoratorclass TorchUtMetrics:'''用于统计测试结果,比较之前的最小值'''def __init__(self,ut_name,thresold=0.2,rank=0):self.ut_name=f"{ut_name}_{rank}"self.thresold=thresoldself.rank=rankself.data={"ut_name":self.ut_name,"metrics":[]}self.metrics_path=os.path.join(metric_data_root,f"{self.ut_name}_{self.rank}.jon")try:with open(self.metrics_path,"r") as f:self.data=json.loads(f.read())except:passdef __enter__(self):self.beg= datetime.now().timestamp() * 1e6return selfdef __exit__(self, exc_type, exc_val, exc_tb):        self.report()self.save_data()def save_data(self):with open(self.metrics_path,"w") as f:f.write(json.dumps(self.data,indent=4))def set_metrics(self,metrics):self.end=datetime.now().timestamp() * 1e6item=collections.OrderedDict()item["time"]=datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')item["sdk_version"]=sdk_versionitem["device_name"]=device_nameitem["host_name"]=host_nameitem["metrics"]=metricsitem["metrics"]["e2e_time"]=self.end-self.begself.cur_item=itemself.data["metrics"].append(self.cur_item)def get_metric_names(self):return self.data["metrics"][0]["metrics"].keys()def get_min_metric(self,metric_name,devicename=None):min_value=0min_value_index=-1for idx,item in enumerate(self.data["metrics"]):if devicename and (devicename!=item['device_name']):                continue            val=float(item["metrics"][metric_name])if min_value_index==-1 or val<min_value:min_value=valmin_value_index=idxreturn min_value,min_value_indexdef get_metric_info(self,index):metrics=self.data["metrics"][index]return f'{metrics["device_name"]}@{metrics["sdk_version"]}'def report(self):assert len(self.data["metrics"])>0for metric_name in self.get_metric_names():min_value,min_value_index=self.get_min_metric(metric_name)min_value_same_dev,min_value_index_same_dev=self.get_min_metric(metric_name,device_name)cur_value=float(self.cur_item["metrics"][metric_name])print(f"-------------------------------{metric_name}-------------------------------")print(f"{cur_value}#{device_name}@{sdk_version}")if min_value_index_same_dev>=0:print(f"{min_value_same_dev}#{self.get_metric_info(min_value_index_same_dev)}")if min_value_index>=0:print(f"{min_value}#{self.get_metric_info(min_value_index)}")

二.普通算子测试[test_clone.py]

from common import *
class TestCaseClone(TestCase):#如果不满足条件,则跳过这个测试@unittest.skipIf(device_count>1, "Not enough devices") def test_todo(self):print(".TODO")#框架会自动遍历以下参数组合@parametrize("shape", [(10240,20480),(128,256)])@parametrize("dtype", [torch.float16,torch.float32])def test_clone(self,shape,dtype):#让这个函数循环执行loops次,统计第一次执行的耗时、后半段的平均时间、整个执行过程总的GPU内存使用量@loop_decorator(loops=5)def run(input_dev):output=input_dev.clone()return output#记录整个测试的总耗时,保存统计量,输出摘要(self._testMethodName:测试方法,result:函数返回值,metrics:统计量)with TorchUtMetrics(ut_name=self._testMethodName,thresold=0.2) as m:input_host=torch.ones(shape,dtype=dtype)*np.random.rand()input_dev=input_host.to(device)metrics=run(input_dev,golden=input_host.cpu())m.set_metrics(metrics)assert(metrics["mse"]==0)instantiate_parametrized_tests(TestCaseClone)if __name__ == "__main__":run_tests()

三.集合通信测试[test_ccl.py]

from common import *
class TestCCL(MultiProcessTestCase):'''CCL测试用例'''def _create_process_group_vccl(self, world_size, store):dist.init_process_group(ccl_backend, world_size=world_size, rank=self.rank, store=store)        pg = dist.distributed_c10d._get_default_group()return pgdef setUp(self):super().setUp()self._spawn_processes()def tearDown(self):super().tearDown()try:os.remove(self.file_name)except OSError:pass@propertydef world_size(self):return 4#框架会自动遍历以下参数组合@unittest.skipIf(device_count<4, "Not enough devices") @parametrize("op",[dist.ReduceOp.SUM])@parametrize("shape", [(1024,8192)])@parametrize("dtype", [torch.int64])def test_allreduce(self,op,shape,dtype):if self.rank >= self.world_size:returnstore = dist.FileStore(self.file_name, self.world_size)pg = self._create_process_group_vccl(self.world_size, store)if not torch.distributed.is_initialized():returntorch.cuda.set_device(self.rank)device = torch.device(device_type,self.rank)device_warmup(device)#让这个函数循环执行loops次,统计第一次执行的耗时、后半段的平均时间、整个执行过程总的GPU内存使用量@loop_decorator(loops=5,rank=self.rank)def run(input_dev):dist.all_reduce(input_dev, op=op)return input_dev#记录整个测试的总耗时,保存统计量,输出摘要(self._testMethodName:测试方法,result:函数返回值,metrics:统计量)with TorchUtMetrics(ut_name=self._testMethodName,thresold=0.2,rank=self.rank) as m:input_host=torch.ones(shape,dtype=dtype)*(100+self.rank)gt=[torch.ones(shape,dtype=dtype)*(100+i) for i in range(self.world_size)]gt_=gt[0]for i in range(1,self.world_size):gt_=gt_+gt[i]input_dev=input_host.to(device)metrics=run(input_dev,golden=gt_)m.set_metrics(metrics)assert(metrics["mse"]==0)dist.destroy_process_group(pg)instantiate_parametrized_tests(TestCCL)if __name__ == "__main__":run_tests()

四.测试命令

# 运行所有的测试
pytest -v -s -p no:warnings --html=torch_report.html --self-contained-html --capture=sys ./# 运行某一个测试
python3 test_clone.py -k "test_clone_shape_(128, 256)_float32"

五.测试报告

在这里插入图片描述


http://www.ppmy.cn/ops/6161.html

相关文章

MapReduce分区机制(Hadoop)

在MapReduce中&#xff0c;分区&#xff08;Partitioning&#xff09;是将Map阶段输出的键值对根据某种规则分发到不同的Reduce任务上的过程。这个过程非常关键&#xff0c;因为它直接影响到了Reduce阶段的负载均衡和性能。 1. 哈希分区&#xff08;Hash Partitioning&#xf…

亚信安全数据安全运营平台DSOP新版本发布 注入AI研判升维

在当今快速发展的数字经济时代&#xff0c;企业对于数据的依赖日益加深&#xff0c;数据安全已成为企业的生命线。亚信安全推出数据安全运营平台DSOP全新版本&#xff0c;正是为满足企业对数据安全的高度需求而设计。这款平台以其卓越的能力和技术优势&#xff0c;为企业的数据…

【php快速上手(十一)】

目录 PHP快速上手&#xff08;十一&#xff09;PHP 连接数据库和创建数据库PHP 连接数据库使用 MySQLi连接 MySQL 数据库使用 PDO 连接 MySQL 数据库 PHP创建数据库使用MySQLi创建MySQL数据库&#xff1a;使用PDO创建MySQL数据库&#xff1a; PHP快速上手&#xff08;十一&…

JVM之JVM栈的详细解析

Java 栈 Java 虚拟机栈&#xff1a;Java Virtual Machine Stacks&#xff0c;每个线程运行时所需要的内存 每个方法被执行时&#xff0c;都会在虚拟机栈中创建一个栈帧 stack frame&#xff08;一个方法一个栈帧&#xff09; Java 虚拟机规范允许 Java 栈的大小是动态的或者是…

「 网络安全常用术语解读 」漏洞利用交换VEX详解

漏洞利用交换&#xff08;Vulnerability Exploitability eXchange&#xff0c;简称VEX&#xff09;是一个信息安全领域的标准&#xff0c;旨在提供关于软件漏洞及其潜在利用的实时信息。根据美国政府发布的用例(PDF)&#xff0c;由美国政府开发的漏洞利用交换(VEX)使供应商和用…

韩国机器人公司Rainbow Robotics推出RB-Y1轮式双臂机器人

文 | BFT机器人 近日&#xff0c;韩国机器人领域的佼佼者Rainbow Robotics揭开了RB-Y1移动机器人的神秘面纱&#xff0c;这款机器人以其创新的设计和卓越的功能引起了业界的广泛关注。与此同时&#xff0c;Rainbow Robotics还携手舍弗勒集团&#xff08;提供汽车、工业技术服务…

js 遍历数据结构,使不符合条件的全部删除

js 遍历数据结构&#xff0c;使不符合条件的全部删除 let newSourceJSON.parse(JSON.stringify(state.treeData))state.expandedKeys[]checkedKeys.map((item:any)>{loop(newSource,{jsonPath:item.split(&)[1]},state.expandedKeys)})function removeUnwantedNodes(tre…

【算法一则】矩阵置零 【矩阵】【空间复用】

题目 给定一个 m x n 的矩阵&#xff0c;如果一个元素为 0 &#xff0c;则将其所在行和列的所有元素都设为 0 。请使用 原地 算法。 示例 1&#xff1a; 输入&#xff1a;matrix [[1,1,1],[1,0,1],[1,1,1]] 输出&#xff1a;[[1,0,1],[0,0,0],[1,0,1]]示例 2&#xff1a; …