Intercepting PyTorch operators and dumping their inputs and outputs



  • 1. Code
  • 2. Output

The goal is to dump the inputs and outputs of every PyTorch operator. PyTorch's ordinary hook mechanism can only intercept modules, so below is a method that also intercepts operators such as torch.add and torch.Tensor.add. The principle: use template substitution to hijack the operators exposed on torch and torch.Tensor with logging wrappers, and intercept the backward pass by traversing next_functions from the loss's grad_fn and calling register_hook on every node of the autograd graph.
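Stripped of the template machinery, the forward-side trick is ordinary monkey-patching: keep a reference to the original operator, install a wrapper that records the arguments and the result, and delegate to the original. Below is a minimal, print-only sketch for a single operator (torch.add); the names _native_add and _patched_add are illustrative and not part of the code in the next section, which generates one such wrapper per eligible operator and also saves each tensor to disk.

import torch

_native_add = torch.add  # keep a reference to the original implementation

def _patched_add(*args, **kwargs):
    # record tensor inputs, call the real op, then record the output
    for i, a in enumerate(args):
        if isinstance(a, torch.Tensor):
            print("add-input", i, a.shape)
    ret = _native_add(*args, **kwargs)
    if isinstance(ret, torch.Tensor):
        print("add-output", 0, ret.shape)
    return ret

torch.add = _patched_add

z = torch.add(torch.ones(2, 3), torch.ones(2, 3))  # prints add-input 0/1 and add-output 0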

1. Code

python">import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
import torch
from torch import nn
import math
import torch.nn.functional as F
from torch.autograd import Variable
import time
import os
import threading
import base64
from jinja2 import Templatedevice="cuda"class Attention(nn.Module):def __init__(self,max_seq_len,head_dim,flash):super().__init__()self.flash = flash #hasattr(torch.nn.functional, 'scaled_dot_product_attention')self.dropout=0self.attn_dropout = nn.Dropout(self.dropout)self.head_dim=head_dimif not self.flash:print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")mask = torch.full((1, 1, max_seq_len, max_seq_len), float("-inf")).to(device)mask = torch.triu(mask, diagonal=1).half().to(device)self.register_buffer("mask", mask)		def forward(self,xq: torch.Tensor,xk: torch.Tensor,xv: torch.Tensor):if self.flash:output = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv,attn_mask=None, dropout_p=self.dropout if self.training else 0.0, is_causal=True)else:_xk=xk.clone()t=_xk.transpose(2, 3)scores = torch.matmul(xq,t)scores = scores/math.sqrt(self.head_dim)a=self.mask[:, :, :seqlen, :seqlen]scores = torch.add(scores,a)scores = F.softmax(scores.float(), dim=-1)scores = scores.type_as(xq)scores = self.attn_dropout(scores)output = torch.matmul(scores, xv)  return outputlock=threading.Lock()
gindex=0
def save_tensor(name,args,index=0):if isinstance(args,torch.Tensor):print(name,index,args.shape)global gindexlock.acquire()torch.save(args,"{}_{}_{}_{}.pt".format(device,gindex,name,index))gindex+=1lock.release()if isinstance(args,tuple):for idx,x in enumerate(args):save_tensor(name,x,index+idx)op_template=Template('''      
native1_{{new_name}}=getattr(torch.Tensor,'{{name}}')
def {{new_name}}(*args, **kwargs):save_tensor("{{name}}-input",args)    global native1_{{new_name}}             ret=native1_{{new_name}}(*args, **kwargs)save_tensor("{{name}}-output",ret)   return ret
setattr(torch.Tensor, '{{name}}', {{new_name}})
''')for op in dir(torch.Tensor):if op in ["__iter__","shape","dim","unbind","normal_","data","item","numel","save","has_names","data_ptr","untyped_storage","storage_offset","size","stride","triu","half","is_floating_point","to","ones","randint","ones_like"]:continueif getattr(torch.Tensor,op).__class__.__name__ not in ["method_descriptor"]:continuenew_name=base64.b64encode(str(f"torch.Tensor.{op}").encode('utf-8')).decode("utf-8").replace("=","")exec(op_template.render(name=op,new_name=new_name))op_template=Template('''      
native2_{{new_name}}=getattr(torch,'{{name}}')
def {{new_name}}(*args, **kwargs):save_tensor("{{name}}-input",args)    global native2_{{new_name}}             ret=native2_{{new_name}}(*args, **kwargs)save_tensor("{{name}}-output",ret) return ret
setattr(torch, '{{name}}', {{new_name}})
''')for op in dir(torch):if op in ["is_grad_enabled","__iter__","save","has_names","data_ptr","untyped_storage","storage_offset","size","stride","triu","is_floating_point","to","ones","randint","full","reshape","ones_like"]:continueif getattr(torch,op).__class__.__name__ not in ["builtin_function_or_method"]:continuenew_name=base64.b64encode(str(f"torch.{op}").encode('utf-8')).decode("utf-8").replace("=","")exec(op_template.render(name=op,new_name=new_name))def hook_backwards(loss, cached):if loss is None:return    def posthook(*args,**kwargs):save_tensor(loss.__class__.__name__,args)def prehook(*args,**kwargs):passloss.register_prehook(prehook)loss.register_hook(posthook)cached.add(loss)for _, child in enumerate(loss.next_functions):if child[0] not in cached:hook_backwards(child[0],cached)def main(flash,bs, n_local_heads, seqlen, head_dim):torch.random.manual_seed(1)q = torch.ones((bs, n_local_heads, seqlen, head_dim),dtype=torch.float32).half().to(device)k = torch.ones((bs, n_local_heads, seqlen, head_dim),dtype=torch.float32).half().to(device)v = torch.ones((bs, n_local_heads, seqlen, head_dim),dtype=torch.float32).half().to(device)q.data.normal_(0, 0.1)k.data.normal_(0, 0.1)v.data.normal_(0, 0.1)q=Variable(q, requires_grad=True).to(device)k=Variable(k, requires_grad=True).to(device)v=Variable(v, requires_grad=True).to(device)gt= torch.randint(0,head_dim,(bs*n_local_heads*seqlen,1)).reshape(-1).to(device)loss_func=nn.CrossEntropyLoss().to(device)model=Attention(seqlen,head_dim,flash).half().to(device)optim = torch.optim.SGD([q,k,v], lr=1.1)for i in range(1):output = model(q,k,v)loss=loss_func(output.reshape(-1,head_dim),gt)hook_backwards(loss.grad_fn, cached=set())loss.backward()  optim.step()print("{:.5f},{:.5f},{:.5f},{:.5f}".format(q.sum().item(),k.sum().item(),v.sum().item(),loss.item()))bs, n_local_heads, seqlen, head_dim = 8, 8, 512, 64
main(False,bs, n_local_heads, seqlen, head_dim)
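For reference, the backward-side interception done by hook_backwards can be shown in isolation on a trivial graph. This is a sketch under the assumption of a recent PyTorch where grad_fn nodes expose register_hook with a (grad_inputs, grad_outputs) hook signature; walk_and_hook is an illustrative name, not part of the code above.

import torch

def walk_and_hook(node, seen):
    # recursively register a hook on every autograd-graph node reachable from `node`
    if node is None or node in seen:
        return
    seen.add(node)
    name = node.__class__.__name__
    def hook(grad_inputs, grad_outputs):
        print(name, [tuple(g.shape) for g in grad_outputs if g is not None])
    node.register_hook(hook)
    for child, _ in node.next_functions:
        walk_and_hook(child, seen)

x = torch.randn(4, requires_grad=True)
loss = (x * 2).sum()
walk_and_hook(loss.grad_fn, set())
loss.backward()  # prints one line per node, e.g. SumBackward0, MulBackward0, AccumulateGrad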

2. Output

reshape-input 0 torch.Size([32768, 1])
reshape-output 0 torch.Size([32768])
clone-input 0 torch.Size([8, 8, 512, 64])
clone-output 0 torch.Size([8, 8, 512, 64])
transpose-input 0 torch.Size([8, 8, 512, 64])
transpose-output 0 torch.Size([8, 8, 64, 512])
matmul-input 0 torch.Size([8, 8, 512, 64])
matmul-input 1 torch.Size([8, 8, 64, 512])
matmul-output 0 torch.Size([8, 8, 512, 512])
__truediv__-input 0 torch.Size([8, 8, 512, 512])
__truediv__-output 0 torch.Size([8, 8, 512, 512])
add-input 0 torch.Size([8, 8, 512, 512])
add-input 1 torch.Size([1, 1, 512, 512])
add-output 0 torch.Size([8, 8, 512, 512])
float-input 0 torch.Size([8, 8, 512, 512])
float-output 0 torch.Size([8, 8, 512, 512])
softmax-input 0 torch.Size([8, 8, 512, 512])
softmax-output 0 torch.Size([8, 8, 512, 512])
type_as-input 0 torch.Size([8, 8, 512, 512])
type_as-input 1 torch.Size([8, 8, 512, 64])
type_as-output 0 torch.Size([8, 8, 512, 512])
matmul-input 0 torch.Size([8, 8, 512, 512])
matmul-input 1 torch.Size([8, 8, 512, 64])
matmul-output 0 torch.Size([8, 8, 512, 64])
reshape-input 0 torch.Size([8, 8, 512, 64])
reshape-output 0 torch.Size([32768, 64])
NllLossBackward0 0 torch.Size([32768, 64])
NllLossBackward0 1 torch.Size([])
LogSoftmaxBackward0 0 torch.Size([32768, 64])
LogSoftmaxBackward0 1 torch.Size([32768, 64])
ViewBackward0 0 torch.Size([8, 8, 512, 64])
ViewBackward0 1 torch.Size([32768, 64])
UnsafeViewBackward0 0 torch.Size([64, 512, 64])
UnsafeViewBackward0 1 torch.Size([8, 8, 512, 64])
BmmBackward0 0 torch.Size([64, 512, 512])
BmmBackward0 1 torch.Size([64, 512, 64])
BmmBackward0 1 torch.Size([64, 512, 64])
ViewBackward0 0 torch.Size([8, 8, 512, 64])
ViewBackward0 1 torch.Size([64, 512, 64])
ExpandBackward0 0 torch.Size([8, 8, 512, 64])
ExpandBackward0 1 torch.Size([8, 8, 512, 64])
AccumulateGrad 1 torch.Size([8, 8, 512, 64])
ViewBackward0 0 torch.Size([8, 8, 512, 512])
ViewBackward0 1 torch.Size([64, 512, 512])
ExpandBackward0 0 torch.Size([8, 8, 512, 512])
ExpandBackward0 1 torch.Size([8, 8, 512, 512])
ToCopyBackward0 0 torch.Size([8, 8, 512, 512])
ToCopyBackward0 1 torch.Size([8, 8, 512, 512])
SoftmaxBackward0 0 torch.Size([8, 8, 512, 512])
SoftmaxBackward0 1 torch.Size([8, 8, 512, 512])
ToCopyBackward0 0 torch.Size([8, 8, 512, 512])
ToCopyBackward0 1 torch.Size([8, 8, 512, 512])
AddBackward0 0 torch.Size([8, 8, 512, 512])
AddBackward0 1 torch.Size([8, 8, 512, 512])
DivBackward0 0 torch.Size([8, 8, 512, 512])
DivBackward0 1 torch.Size([8, 8, 512, 512])
UnsafeViewBackward0 0 torch.Size([64, 512, 512])
UnsafeViewBackward0 1 torch.Size([8, 8, 512, 512])
BmmBackward0 0 torch.Size([64, 512, 64])
BmmBackward0 1 torch.Size([64, 64, 512])
BmmBackward0 1 torch.Size([64, 512, 512])
ReshapeAliasBackward0 0 torch.Size([8, 8, 64, 512])
ReshapeAliasBackward0 1 torch.Size([64, 64, 512])
ExpandBackward0 0 torch.Size([8, 8, 64, 512])
ExpandBackward0 1 torch.Size([8, 8, 64, 512])
ViewBackward0 0 torch.Size([8, 8, 512, 64])
ViewBackward0 1 torch.Size([64, 512, 64])
ExpandBackward0 0 torch.Size([8, 8, 512, 64])
ExpandBackward0 1 torch.Size([8, 8, 512, 64])
AccumulateGrad 1 torch.Size([8, 8, 512, 64])
TransposeBackward0 0 torch.Size([8, 8, 512, 64])
TransposeBackward0 1 torch.Size([8, 8, 64, 512])
CloneBackward0 0 torch.Size([8, 8, 512, 64])
CloneBackward0 1 torch.Size([8, 8, 512, 64])
AccumulateGrad 1 torch.Size([8, 8, 512, 64])
sum-input 0 torch.Size([8, 8, 512, 64])
sum-output 0 torch.Size([])
sum-input 0 torch.Size([8, 8, 512, 64])
sum-output 0 torch.Size([])
sum-input 0 torch.Size([8, 8, 512, 64])
sum-output 0 torch.Size([])
45.56250,-12.76562,121.68750,4.16016
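Besides the printed shapes, save_tensor writes every intercepted tensor to a file named <device>_<index>_<op>-input/output_<argpos>.pt, so a dump can be reloaded and inspected, or diffed against a dump produced on another device or kernel path. A hypothetical example; the file name assumes the run above with device="cuda", and the cpu file only exists if you rerun the script with device="cpu".

import torch

t = torch.load("cuda_0_reshape-input_0.pt")  # first tensor dumped by the run above
print(t.shape)                               # torch.Size([32768, 1]), matching the log

# e.g. compare against the same tensor dumped from a run with device="cpu"
# ref = torch.load("cpu_0_reshape-input_0.pt")
# print(torch.allclose(t.float().cpu(), ref.float().cpu(), atol=1e-3))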
