python序列化、反序列化函数的参数,用于问题复现

embedded/2024/10/18 3:15:28/

python序列化、反序列化函数的参数,用于问题复现

  • 一.代码
  • 二.输出

一.背景
1.想dump出pytorch模型所调用基础算子的参数
2.对于Tensor,ndarray。只保存type,shape,不存值
3.之后可通过以上保存的信息,生成算子的参数,运行单算子
二.目前支持以下类型及嵌套
Tensor,ndarray,int,float,list,tuple

一.代码

python"># -*- coding: utf-8 -*-
'''
一.背景
1.想dump出pytorch模型所调用基础算子的参数
2.对于Tensor,ndarray。只保存type,shape,不存值
3.之后可通过以上保存的信息,生成算子的参数,运行单算子
二.目前支持以下类型及嵌套:
Tensor,ndarray,int,float,list,tuple
'''import torch
import numpy as np
import pickle
from dataclasses import dataclass
from typing import Anyvar_save_path="vars.pkl"@dataclass
class DataDescriptor:class_name: Anyshape: Anyvalue: Anydtype: Anydef data(self):if self.class_name=="Tensor":return torch.zeros(self.shape,dtype=self.dtype)elif self.class_name in ["int","float"]:return self.valueelif self.class_name in ["ndarray"]:return np.zeros(self.shape,dtype=self.dtype)elif self.class_name in ["list","tuple"]:output=[]for t in self.value:output.append(t.data())return outputelse:raise f"Unkown:{self.class_name}"def __repr__(self) -> str:output_str=[]if self.shape:output_str.append("shape:({})".format(",".join([str(x) for x in self.shape])))if self.value:if self.class_name in ["list","tuple"]:for t in self.value:output_str.append(str(t))else:output_str.append(str(self.value))if self.dtype and self.class_name in ["Tensor","ndarray"]:output_str.append(str(self.dtype))return "{}({})".format(self.class_name,"-".join(output_str))class InputDescriptor:def __init__(self) -> None:self.input_vars=[]self.input_kwargs={}def serialize(self,path):with open(path,"wb") as f:pickle.dump(self,f)@classmethoddef deserialize(cls,path):with open(path,"rb") as f:return pickle.load(f)def data(self):input_vars=[]input_kwargs={}for var in self.input_vars:input_vars.append(var.data())for k,v in self.input_kwargs.items():input_kwargs[k]=v.data()return input_vars,input_kwargsdef _save_var(self,v):class_name=v.__class__.__name__if class_name=="Tensor":return DataDescriptor(class_name,list(v.shape),None,v.dtype)elif class_name in ["int","float"]:return DataDescriptor(class_name,None,v,type(v))elif class_name in ["ndarray"]:return DataDescriptor(class_name,list(v.shape),None,v.dtype)elif class_name in ["list","tuple"]:output=[]for t in v:output.append(self._save_var(t))return DataDescriptor(class_name,None,output,None)else:raise f"Unkown:{class_name}"def save_vars(self,*args,**kwargs):for arg in args:self.input_vars.append(self._save_var(arg))for k,v in kwargs.items():self.input_kwargs[k]=self._save_var(v)def __repr__(self) -> str:return str(self.input_vars) + "#" + str(self.input_kwargs)     
def do_something(*args,**kwargs):'''某个函数'''print(args,kwargs)  # 保存输入参数,序列化保存到文件desc=InputDescriptor()desc.save_vars(*args,**kwargs)desc.serialize(var_save_path)return Truedef main():# 1.准备输入参数input_vars=[]input_kwargs={}input_vars.append(torch.zeros((1,23,128),dtype=torch.float32))input_vars.append(1)input_vars.append(2.0)input_vars.append((1,2,3))input_vars.append(np.zeros((2,3,4),dtype=np.float32))input_vars.append([ torch.zeros((1,23,128),dtype=torch.float32),torch.zeros((1,23,128),dtype=torch.float32)])input_vars.append([np.zeros((2,3,4),dtype=np.float32)])input_kwargs["a"]=1input_kwargs["b"]=4input_kwargs["c"]=input_varsinput_kwargs["d"]=torch.zeros((1,23,128),dtype=torch.float32)# 2.调用用户函数,在函数中对参数序列化do_something(*input_vars,**input_kwargs)print("#"*32+" reproduce "+"#"*32)# 3.加载序列化文件desc=InputDescriptor.deserialize(var_save_path)# 4.生成参数_input_vars,_input_kwargs=desc.data()# 5.再次调用用户函数do_something(*_input_vars,**_input_kwargs)# 6.打印参数print(desc)if __name__ == "__main__":  main()

二.输出

[Tensor(shape:(1,23,128)-torch.float32), int(1), float(2.0), tuple(int(1)-int(2)-int(3)), 
ndarray(shape:(2,3,4)-float32), list(
Tensor(shape:(1,23,128)-torch.float32)-Tensor(shape:(1,23,128)-torch.float32)), 
list(ndarray(shape:(2,3,4)-float32))]
#{'a': int(1), 'b': int(4),
'c': list(Tensor(shape:(1,23,128)-torch.float32)
-int(1)-float(2.0)-tuple(int(1)-int(2)-int(3))
-ndarray(shape:(2,3,4)-float32)-list(Tensor(shape:(1,23,128)
-torch.float32)-Tensor(shape:(1,23,128)-torch.float32))
-list(ndarray(shape:(2,3,4)-float32))), 'd': Tensor(shape:(1,23,128)-torch.float32)}

http://www.ppmy.cn/embedded/26451.html

相关文章

【分享】如何将word格式文档转化为PDF格式

在日常的办公和学习中,我们经常需要将Word文档转换为PDF格式。PDF作为一种通用的文件格式,具有跨平台、易读性高等优点,因此在许多场合下都更为适用。那么,如何实现Word转PDF呢?本文将介绍几种常用的方法,帮…

Ubuntu Linux完全入门视频教程

Ubuntu Linux完全入门视频教程 UbuntuLinux完全入门视频教程1.rar UbuntuLinux亮全入门视频教程10.ra UbuntuLinux亮全入门视频教程11.ra UbuntuLinux完全入门视频教程12.ra UbuntuLinux亮全入门视频教程13.ra UbuntuLinux完全入门视频教程14.rar UbuntuLinux完全入门视频教程…

04 Docker练习赛从0开始到 docker 镜像提交

1.1 本地安装 docker 工具 这里以ubutun下安装docker为例,其他操作系统安装命令略有不同,可自行百度。(建议使用阿里源安装速度快) sudo apt install docker.io如果你本地有gpu,请继续执行如下命令以支持gpu调用: 注意: 英伟达对 docker 支持的 linux 发行版:https:/…

测试工程师面试准备(软硬件)

您好,我叫XXX。学历XX,XXX专业毕业。X年X月份毕业,但是去年二月份已经找到工作开始实习了,目前工作一年了,这一年的过程中我主要负责软件的开发和测试和软硬件联调测试工作。具体来说就是,在软件开发完成后…

nacos(docker部署)+springboot集成

文章目录 说明零nacos容器部署初始化配置高级配置部分访问权限控制命名空间设置新建配置文件 springboot配置nacos添加依赖编写测试controller 说明 nacos容器部署采用1Panel运维面板,进行部署操作,简化操作注意提前安装好1Panel和配置完成docker镜像加…

为什么 Facebook 不使用 Git?

在编程的世界里,Git 就像水一样常见,以至于我们认为它是创建和管理代码更改的唯一可行的工具。 前 Facebook 员工,2024 年 首先,我为什么关心? 我致力于构建 Graphite,它从根本上受到 Facebook 内部工具的…

头歌:Spark Streaming

第1关:套接字流实现黑名单过滤 简介 套接字流是通过监听Socket端口接收的数据,相当于Socket之间的通信,任何用户在用Socket(套接字)通信之前,首先要先申请一个Socket号,Socket号相当于该用户…

等保测评:网络安全合规的基石

在数字化时代,网络安全已成为国家安全战略的重要组成部分。信息安全等级保护测评(等保测评)作为网络安全合规的核心,对于维护网络空间的安全稳定、保护企业和个人的信息资产具有不可替代的作用。 ## 一、等保测评的法律地位 等保…