pytest学习-pytorch单元测试

news/2024/10/18 16:49:29/

pytorch单元测试

  • 一.公共模块[common.py]
  • 二.普通算子测试[test_clone.py]
  • 三.集合通信测试[test_ccl.py]
  • 四.测试命令
  • 五.测试报告

希望测试pytorch各种算子、block、网络等在不同硬件平台,不同软件版本下的计算误差、耗时、内存占用等指标.

本文基于torch.testing._internal

一.公共模块[common.py]

import torch
from torch import nn
import math
import torch.nn.functional as F
import time
import os
import socket
import sys
from datetime import datetime
import numpy as np
import collections
import math
import json
import copy
import traceback
import subprocess
import unittest
import torch
import inspect
from torch.testing._internal.common_utils import TestCase, run_tests,parametrize,instantiate_parametrized_tests
from torch.testing._internal.common_distributed import MultiProcessTestCase
import torch.distributed as distos.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29500"
os.environ["RANDOM_SEED"] = "0" device="cpu"
device_type="cpu"
device_name="cpu"try:if torch.cuda.is_available():     device_name=torch.cuda.get_device_name().replace(" ","")device="cuda:0"device_type="cuda"ccl_backend='nccl'
except:passhost_name=socket.gethostname()    
sdk_version=os.getenv("SDK_VERSION","")   						 #从环境变量中获取sdk版本号
metric_data_root=os.getenv("TORCH_UT_METRICS_DATA","./ut_data")  #日志存放的目录
device_count=torch.cuda.device_count()if not os.path.exists(metric_data_root):os.makedirs(metric_data_root)def device_warmup(device):'''设备warmup,确保设备已经正常工作,排除设备初始化的耗时'''left = torch.rand([128,512], dtype = torch.float16).to(device)right = torch.rand([512,128], dtype = torch.float16).to(device)out=torch.matmul(left,right)torch.cuda.synchronize()torch.manual_seed(1) 
np.random.seed(1)def loop_decorator(loops,rank=0):'''循环装饰器,用于统计函数的执行时间,内存占用等'''def decorator(func):def wrapper(*args,**kwargs):latency=[]memory_allocated_t0=torch.cuda.memory_allocated(rank)for _ in range(loops):input_copy=[x.clone() for x in args]beg= datetime.now().timestamp() * 1e6pred= func(*input_copy)gt=kwargs["golden"]torch.cuda.synchronize()end=datetime.now().timestamp() * 1e6mse = torch.mean(torch.pow(pred.cpu().float()- gt.cpu().float(), 2)).item()latency.append(end-beg)memory_allocated_t1=torch.cuda.memory_allocated(rank)avg_latency=np.mean(latency[len(latency)//2:]).round(3)first_latency=latency[0]return { "first_latency":first_latency,"avg_latency":avg_latency,"memory_allocated":memory_allocated_t1-memory_allocated_t0,"mse":mse}return wrapperreturn decoratorclass TorchUtMetrics:'''用于统计测试结果,比较之前的最小值'''def __init__(self,ut_name,thresold=0.2,rank=0):self.ut_name=f"{ut_name}_{rank}"self.thresold=thresoldself.rank=rankself.data={"ut_name":self.ut_name,"metrics":[]}self.metrics_path=os.path.join(metric_data_root,f"{self.ut_name}_{self.rank}.jon")try:with open(self.metrics_path,"r") as f:self.data=json.loads(f.read())except:passdef __enter__(self):self.beg= datetime.now().timestamp() * 1e6return selfdef __exit__(self, exc_type, exc_val, exc_tb):        self.report()self.save_data()def save_data(self):with open(self.metrics_path,"w") as f:f.write(json.dumps(self.data,indent=4))def set_metrics(self,metrics):self.end=datetime.now().timestamp() * 1e6item=collections.OrderedDict()item["time"]=datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')item["sdk_version"]=sdk_versionitem["device_name"]=device_nameitem["host_name"]=host_nameitem["metrics"]=metricsitem["metrics"]["e2e_time"]=self.end-self.begself.cur_item=itemself.data["metrics"].append(self.cur_item)def get_metric_names(self):return self.data["metrics"][0]["metrics"].keys()def get_min_metric(self,metric_name,devicename=None):min_value=0min_value_index=-1for idx,item in enumerate(self.data["metrics"]):if devicename and (devicename!=item['device_name']):                continue            val=float(item["metrics"][metric_name])if min_value_index==-1 or val<min_value:min_value=valmin_value_index=idxreturn min_value,min_value_indexdef get_metric_info(self,index):metrics=self.data["metrics"][index]return f'{metrics["device_name"]}@{metrics["sdk_version"]}'def report(self):assert len(self.data["metrics"])>0for metric_name in self.get_metric_names():min_value,min_value_index=self.get_min_metric(metric_name)min_value_same_dev,min_value_index_same_dev=self.get_min_metric(metric_name,device_name)cur_value=float(self.cur_item["metrics"][metric_name])print(f"-------------------------------{metric_name}-------------------------------")print(f"{cur_value}#{device_name}@{sdk_version}")if min_value_index_same_dev>=0:print(f"{min_value_same_dev}#{self.get_metric_info(min_value_index_same_dev)}")if min_value_index>=0:print(f"{min_value}#{self.get_metric_info(min_value_index)}")

二.普通算子测试[test_clone.py]

from common import *
class TestCaseClone(TestCase):#如果不满足条件,则跳过这个测试@unittest.skipIf(device_count>1, "Not enough devices") def test_todo(self):print(".TODO")#框架会自动遍历以下参数组合@parametrize("shape", [(10240,20480),(128,256)])@parametrize("dtype", [torch.float16,torch.float32])def test_clone(self,shape,dtype):#让这个函数循环执行loops次,统计第一次执行的耗时、后半段的平均时间、整个执行过程总的GPU内存使用量@loop_decorator(loops=5)def run(input_dev):output=input_dev.clone()return output#记录整个测试的总耗时,保存统计量,输出摘要(self._testMethodName:测试方法,result:函数返回值,metrics:统计量)with TorchUtMetrics(ut_name=self._testMethodName,thresold=0.2) as m:input_host=torch.ones(shape,dtype=dtype)*np.random.rand()input_dev=input_host.to(device)metrics=run(input_dev,golden=input_host.cpu())m.set_metrics(metrics)assert(metrics["mse"]==0)instantiate_parametrized_tests(TestCaseClone)if __name__ == "__main__":run_tests()

三.集合通信测试[test_ccl.py]

from common import *
class TestCCL(MultiProcessTestCase):'''CCL测试用例'''def _create_process_group_vccl(self, world_size, store):dist.init_process_group(ccl_backend, world_size=world_size, rank=self.rank, store=store)        pg = dist.distributed_c10d._get_default_group()return pgdef setUp(self):super().setUp()self._spawn_processes()def tearDown(self):super().tearDown()try:os.remove(self.file_name)except OSError:pass@propertydef world_size(self):return 4#框架会自动遍历以下参数组合@unittest.skipIf(device_count<4, "Not enough devices") @parametrize("op",[dist.ReduceOp.SUM])@parametrize("shape", [(1024,8192)])@parametrize("dtype", [torch.int64])def test_allreduce(self,op,shape,dtype):if self.rank >= self.world_size:returnstore = dist.FileStore(self.file_name, self.world_size)pg = self._create_process_group_vccl(self.world_size, store)if not torch.distributed.is_initialized():returntorch.cuda.set_device(self.rank)device = torch.device(device_type,self.rank)device_warmup(device)#让这个函数循环执行loops次,统计第一次执行的耗时、后半段的平均时间、整个执行过程总的GPU内存使用量@loop_decorator(loops=5,rank=self.rank)def run(input_dev):dist.all_reduce(input_dev, op=op)return input_dev#记录整个测试的总耗时,保存统计量,输出摘要(self._testMethodName:测试方法,result:函数返回值,metrics:统计量)with TorchUtMetrics(ut_name=self._testMethodName,thresold=0.2,rank=self.rank) as m:input_host=torch.ones(shape,dtype=dtype)*(100+self.rank)gt=[torch.ones(shape,dtype=dtype)*(100+i) for i in range(self.world_size)]gt_=gt[0]for i in range(1,self.world_size):gt_=gt_+gt[i]input_dev=input_host.to(device)metrics=run(input_dev,golden=gt_)m.set_metrics(metrics)assert(metrics["mse"]==0)dist.destroy_process_group(pg)instantiate_parametrized_tests(TestCCL)if __name__ == "__main__":run_tests()

四.测试命令

# 运行所有的测试
pytest -v -s -p no:warnings --html=torch_report.html --self-contained-html --capture=sys ./# 运行某一个测试
python3 test_clone.py -k "test_clone_shape_(128, 256)_float32"

五.测试报告

在这里插入图片描述


http://www.ppmy.cn/news/1427876.html

相关文章

【攻防世界】warmup

[HCTF 2018]WarmUp全网最详细解释_[hctf 2018]warmup的解-CSDN博客 php://filter 读取源码&#xff08;文件&#xff09; php://input 执行php代码&#xff0c;需要post请求提交数据 Content-Type为image/jpeg text. 绕过后缀的有文件格式有php,php3,php4,php5,pht…

面试 Python 基础八股文十问十答第二期

面试 Python 基础八股文十问十答第二期 作者&#xff1a;程序员小白条&#xff0c;个人博客 相信看了本文后&#xff0c;对你的面试是有一定帮助的&#xff01;关注专栏后就能收到持续更新&#xff01; ⭐点赞⭐收藏⭐不迷路&#xff01;⭐ 1&#xff09;为什么有了GIL还要关…

如何根据表名快速定位引用该表的Oracle存储过程

如何根据表名快速定位引用该表的Oracle存储过程 引言场景一&#xff1a;常规查询 - USER_DEPENDENCIES场景二&#xff1a;基于源码搜索 - USER_SOURCE场景三&#xff1a;复杂依赖分析总结与注意事项 引言 在数据库管理和维护过程中&#xff0c;当我们计划对某张特定表进行结构调…

接口压力测试 jmeter--入门篇(一)

一 压力测试的目的 评估系统的能力识别系统的弱点&#xff1a;瓶颈/弱点检查系统的隐藏的问题检验系统的稳定性和可靠性 二 性能测试指标以及测算 【虚拟用户数】&#xff1a;线程用户【并发数】&#xff1a;指在某一时间&#xff0c;一定数量的虚拟用户同时对系统的某个功…

《Java面试自救指南》(专题七)系统场景设计(含分布式、微服务)

文章目录 负载均衡如何实现,有哪几种方式谈谈你对微服务的理解SOA和微服务的区别CAP理论和BASE定理分布式系统需要考虑哪些问题分布式系统如何实现数据一致性如何实现分布式锁你的服务挂了怎么处理限流算法原理和应用分布式ID生成策略一致性算法(2/3pc, paxos, Raft, ZAB)淘…

(0)(0.2) 接近传感器

文章目录 前言 1 配置 2 测试 3 附加功能 前言 Copter/Rover 支持避开飞行器前方可能出现的障碍物。启用这些功能的第一步是安装一个正常工作的接近传感器。ArduPilot 最多支持 4 个传感器。 360 度激光雷达通常作为近距离传感器用于物体回避&#xff0c;但也可将多个测距…

过零可控硅光耦与随机可控硅光耦

无过零检测 推荐型号 MOC3021无过零检测 对应的数据手册 原理框图 工作电流 过零检测 推荐型号 MOC3061 原理框图 工作电流 注意事项 随机导通型是随时打开的。都是过零时关闭 也即是说&#xff1a;过零型打开的都是一个馒头波。 参考链接 过零可控硅光耦怎么用-电路知识干…

Linux中断——嵌入式Linux驱动开发

参考正点原子I.MX6U嵌入式Linux驱动开发指南 一、简介 先来简单了解一般中断的处理方法&#xff1a; ①、使能中断&#xff0c;初始化相应的寄存器。 ②、注册中断服务函数&#xff0c;也就是向 irqTable 数组的指定标号处写入中断服务函数 ③、中断发生以后进入 IRQ 中…