python序列化、反序列化函数的参数,用于问题复现
- 一.代码
- 二.输出
一.背景
1.想dump出pytorch模型所调用基础算子的参数
2.对于Tensor、ndarray,只保存type和shape,不保存值
3.之后可通过以上保存的信息,生成算子的参数,运行单算子
二.目前支持以下类型及嵌套
Tensor,ndarray,int,float,list,tuple
一.代码
# -*- coding: utf-8 -*-
'''
一.背景
1.想dump出pytorch模型所调用基础算子的参数
2.对于Tensor、ndarray,只保存type和shape,不保存值
3.之后可通过以上保存的信息,生成算子的参数,运行单算子
二.目前支持以下类型及嵌套:
Tensor,ndarray,int,float,list,tuple
'''
import torch
import numpy as np
import pickle
from dataclasses import dataclass
from typing import Any

# Default file that the serialized argument description is written to.
var_save_path = "vars.pkl"


@dataclass
class DataDescriptor:
    """Value-free description of a single call argument.

    Tensors and ndarrays are described by shape and dtype only; ints and
    floats keep their value; lists and tuples hold nested descriptors.
    """

    class_name: Any  # "Tensor", "ndarray", "int", "float", "list" or "tuple"
    shape: Any       # list of dimensions for Tensor/ndarray, otherwise None
    value: Any       # scalar value, or list of child DataDescriptor objects
    dtype: Any       # dtype for Tensor/ndarray, Python type for scalars

    def data(self):
        """Build a concrete argument matching this description.

        Returns:
            A zero-filled Tensor/ndarray of the recorded shape and dtype, the
            recorded scalar, or a list/tuple of recursively built items.

        Raises:
            TypeError: if ``class_name`` is not one of the supported kinds.
        """
        if self.class_name == "Tensor":
            return torch.zeros(self.shape, dtype=self.dtype)
        if self.class_name == "ndarray":
            return np.zeros(self.shape, dtype=self.dtype)
        if self.class_name in ("int", "float"):
            return self.value
        if self.class_name in ("list", "tuple"):
            items = [child.data() for child in self.value]
            # Rebuild the original container kind so the reproduced call
            # receives a tuple where the recorded call received one.
            return tuple(items) if self.class_name == "tuple" else items
        raise TypeError(f"Unsupported type: {self.class_name}")

    def __repr__(self) -> str:
        """Return ``class_name(part-part-...)`` listing shape, value, dtype."""
        parts = []
        # Compare against None instead of testing truthiness so that an
        # empty shape or a zero-valued scalar is still printed.
        if self.shape is not None:
            dims = ",".join(str(dim) for dim in self.shape)
            parts.append("shape:({})".format(dims))
        if self.value is not None:
            if self.class_name in ("list", "tuple"):
                parts.extend(str(child) for child in self.value)
            else:
                parts.append(str(self.value))
        # "is not None" avoids depending on the truthiness of dtype objects.
        if self.dtype is not None and self.class_name in ("Tensor", "ndarray"):
            parts.append(str(self.dtype))
        return "{}({})".format(self.class_name, "-".join(parts))


class InputDescriptor:
    """Records the positional and keyword arguments of one call as descriptors."""

    def __init__(self) -> None:
        self.input_vars = []     # DataDescriptor per positional argument
        self.input_kwargs = {}   # keyword name -> DataDescriptor

    def serialize(self, path):
        """Pickle this descriptor to the file at ``path``."""
        with open(path, "wb") as f:
            pickle.dump(self, f)

    @classmethod
    def deserialize(cls, path):
        """Load and return the object pickled in the file at ``path``."""
        # SECURITY: pickle.load can execute arbitrary code while loading.
        # Only open files produced by serialize() from a trusted source.
        with open(path, "rb") as f:
            return pickle.load(f)

    def data(self):
        """Build concrete arguments from the recorded descriptors.

        Returns:
            A ``(args_list, kwargs_dict)`` pair.
        """
        input_vars = [var.data() for var in self.input_vars]
        input_kwargs = {k: v.data() for k, v in self.input_kwargs.items()}
        return input_vars, input_kwargs

    def _save_var(self, v):
        """Convert one value into a DataDescriptor, recursing into containers.

        Dispatch is on the exact class name, so subclasses of the supported
        types are rejected.

        Raises:
            TypeError: if the value's class is not supported.
        """
        class_name = v.__class__.__name__
        if class_name in ("Tensor", "ndarray"):
            return DataDescriptor(class_name, list(v.shape), None, v.dtype)
        if class_name in ("int", "float"):
            return DataDescriptor(class_name, None, v, type(v))
        if class_name in ("list", "tuple"):
            children = [self._save_var(item) for item in v]
            return DataDescriptor(class_name, None, children, None)
        # Raising a plain string is itself an error in Python 3 and hides the
        # message; raise a real exception instead.
        raise TypeError(f"Unsupported type: {class_name}")

    def save_vars(self, *args, **kwargs):
        """Append descriptors for ``args`` and store descriptors for ``kwargs``."""
        for arg in args:
            self.input_vars.append(self._save_var(arg))
        for k, v in kwargs.items():
            self.input_kwargs[k] = self._save_var(v)

    def __repr__(self) -> str:
        """Return the positional and keyword descriptors joined by ``#``."""
        return str(self.input_vars) + "#" + str(self.input_kwargs)
def do_something(*args, **kwargs):
    """Example user function: print its arguments and dump their description."""
    print(args, kwargs)
    # Record the incoming arguments and serialize the description to a file.
    recorder = InputDescriptor()
    recorder.save_vars(*args, **kwargs)
    recorder.serialize(var_save_path)
    return True


def main():
    """Record a call, reload its description, and replay the call."""
    # 1. Prepare the input arguments.
    input_vars = [
        torch.zeros((1, 23, 128), dtype=torch.float32),
        1,
        2.0,
        (1, 2, 3),
        np.zeros((2, 3, 4), dtype=np.float32),
        [
            torch.zeros((1, 23, 128), dtype=torch.float32),
            torch.zeros((1, 23, 128), dtype=torch.float32),
        ],
        [np.zeros((2, 3, 4), dtype=np.float32)],
    ]
    input_kwargs = {
        "a": 1,
        "b": 4,
        "c": input_vars,
        "d": torch.zeros((1, 23, 128), dtype=torch.float32),
    }

    # 2. Call the user function; it serializes its arguments internally.
    do_something(*input_vars, **input_kwargs)

    rule = "#" * 32
    print(rule + " reproduce " + rule)

    # 3. Load the serialized description.
    desc = InputDescriptor.deserialize(var_save_path)

    # 4. Generate arguments from the description.
    replay_args, replay_kwargs = desc.data()

    # 5. Call the user function again with the generated arguments.
    do_something(*replay_args, **replay_kwargs)

    # 6. Print the argument description.
    print(desc)


if __name__ == "__main__":
    main()
二.输出
[Tensor(shape:(1,23,128)-torch.float32), int(1), float(2.0), tuple(int(1)-int(2)-int(3)),
ndarray(shape:(2,3,4)-float32), list(
Tensor(shape:(1,23,128)-torch.float32)-Tensor(shape:(1,23,128)-torch.float32)),
list(ndarray(shape:(2,3,4)-float32))]
#{'a': int(1), 'b': int(4),
'c': list(Tensor(shape:(1,23,128)-torch.float32)
-int(1)-float(2.0)-tuple(int(1)-int(2)-int(3))
-ndarray(shape:(2,3,4)-float32)-list(Tensor(shape:(1,23,128)
-torch.float32)-Tensor(shape:(1,23,128)-torch.float32))
-list(ndarray(shape:(2,3,4)-float32))), 'd': Tensor(shape:(1,23,128)-torch.float32)}