python序列化、反序列化函数的参数,用于问题复现

news/2024/10/21 6:03:22/

python序列化、反序列化函数的参数,用于问题复现

  • 一.代码
  • 二.输出

一.背景
1.想dump出pytorch模型所调用基础算子的参数
2.对于Tensor,ndarray。只保存type,shape,不存值
3.之后可通过以上保存的信息,生成算子的参数,运行单算子
二.目前支持以下类型及嵌套
Tensor,ndarray,int,float,list,tuple

一.代码

python"># -*- coding: utf-8 -*-
'''
一.背景
1.想dump出pytorch模型所调用基础算子的参数
2.对于Tensor,ndarray。只保存type,shape,不存值
3.之后可通过以上保存的信息,生成算子的参数,运行单算子
二.目前支持以下类型及嵌套:
Tensor,ndarray,int,float,list,tuple
'''import torch
import numpy as np
import pickle
from dataclasses import dataclass
from typing import Anyvar_save_path="vars.pkl"@dataclass
class DataDescriptor:class_name: Anyshape: Anyvalue: Anydtype: Anydef data(self):if self.class_name=="Tensor":return torch.zeros(self.shape,dtype=self.dtype)elif self.class_name in ["int","float"]:return self.valueelif self.class_name in ["ndarray"]:return np.zeros(self.shape,dtype=self.dtype)elif self.class_name in ["list","tuple"]:output=[]for t in self.value:output.append(t.data())return outputelse:raise f"Unkown:{self.class_name}"def __repr__(self) -> str:output_str=[]if self.shape:output_str.append("shape:({})".format(",".join([str(x) for x in self.shape])))if self.value:if self.class_name in ["list","tuple"]:for t in self.value:output_str.append(str(t))else:output_str.append(str(self.value))if self.dtype and self.class_name in ["Tensor","ndarray"]:output_str.append(str(self.dtype))return "{}({})".format(self.class_name,"-".join(output_str))class InputDescriptor:def __init__(self) -> None:self.input_vars=[]self.input_kwargs={}def serialize(self,path):with open(path,"wb") as f:pickle.dump(self,f)@classmethoddef deserialize(cls,path):with open(path,"rb") as f:return pickle.load(f)def data(self):input_vars=[]input_kwargs={}for var in self.input_vars:input_vars.append(var.data())for k,v in self.input_kwargs.items():input_kwargs[k]=v.data()return input_vars,input_kwargsdef _save_var(self,v):class_name=v.__class__.__name__if class_name=="Tensor":return DataDescriptor(class_name,list(v.shape),None,v.dtype)elif class_name in ["int","float"]:return DataDescriptor(class_name,None,v,type(v))elif class_name in ["ndarray"]:return DataDescriptor(class_name,list(v.shape),None,v.dtype)elif class_name in ["list","tuple"]:output=[]for t in v:output.append(self._save_var(t))return DataDescriptor(class_name,None,output,None)else:raise f"Unkown:{class_name}"def save_vars(self,*args,**kwargs):for arg in args:self.input_vars.append(self._save_var(arg))for k,v in kwargs.items():self.input_kwargs[k]=self._save_var(v)def __repr__(self) -> str:return str(self.input_vars) + "#" + str(self.input_kwargs)     
def do_something(*args,**kwargs):'''某个函数'''print(args,kwargs)  # 保存输入参数,序列化保存到文件desc=InputDescriptor()desc.save_vars(*args,**kwargs)desc.serialize(var_save_path)return Truedef main():# 1.准备输入参数input_vars=[]input_kwargs={}input_vars.append(torch.zeros((1,23,128),dtype=torch.float32))input_vars.append(1)input_vars.append(2.0)input_vars.append((1,2,3))input_vars.append(np.zeros((2,3,4),dtype=np.float32))input_vars.append([ torch.zeros((1,23,128),dtype=torch.float32),torch.zeros((1,23,128),dtype=torch.float32)])input_vars.append([np.zeros((2,3,4),dtype=np.float32)])input_kwargs["a"]=1input_kwargs["b"]=4input_kwargs["c"]=input_varsinput_kwargs["d"]=torch.zeros((1,23,128),dtype=torch.float32)# 2.调用用户函数,在函数中对参数序列化do_something(*input_vars,**input_kwargs)print("#"*32+" reproduce "+"#"*32)# 3.加载序列化文件desc=InputDescriptor.deserialize(var_save_path)# 4.生成参数_input_vars,_input_kwargs=desc.data()# 5.再次调用用户函数do_something(*_input_vars,**_input_kwargs)# 6.打印参数print(desc)if __name__ == "__main__":  main()

二.输出

[Tensor(shape:(1,23,128)-torch.float32), int(1), float(2.0), tuple(int(1)-int(2)-int(3)), 
ndarray(shape:(2,3,4)-float32), list(
Tensor(shape:(1,23,128)-torch.float32)-Tensor(shape:(1,23,128)-torch.float32)), 
list(ndarray(shape:(2,3,4)-float32))]
#{'a': int(1), 'b': int(4),
'c': list(Tensor(shape:(1,23,128)-torch.float32)
-int(1)-float(2.0)-tuple(int(1)-int(2)-int(3))
-ndarray(shape:(2,3,4)-float32)-list(Tensor(shape:(1,23,128)
-torch.float32)-Tensor(shape:(1,23,128)-torch.float32))
-list(ndarray(shape:(2,3,4)-float32))), 'd': Tensor(shape:(1,23,128)-torch.float32)}

http://www.ppmy.cn/news/1447332.html

相关文章

智能科技的飞跃:LLAMA3引领的人工智能新时代

大家好!相信大家对于AI(人工智能)的发展已经有了一定的了解,但你是否意识到,到了2024年,AI已经变得如此强大和普及,带来了我们从未想象过的便利和创新呢?让我们一起来看看AI在这个时…

Windows之隐藏特殊文件夹(自定义快捷桌面程序)

作者主页:点击! 创作时间:2024年5月1日12点55分 祝大家劳动节快乐~ Windows中的特殊文件夹是指一些预定义的文件夹,用于存储特定类型的数据或文件。这些文件夹通常由操作系统或应用程序使用,但用户也可以访问和管理它…

C语言:数据结构(双向链表)

目录 1、双向链表的结构2、顺序表和双向链表的优缺点分析3、双向链表的实现 1、双向链表的结构 注意:这⾥的“带头“跟前面我们说的“头节点”是两个概念,实际前面的在单链表阶段称呼不严谨,但是为了更好的理解就直接称为单链表的头节点。 带…

汽车信息安全入门总结(2)

目录 1.引入 2.汽车信息安全技术 3.密码学基础知识 4.小结 1.引入 上篇汽车信息安全入门总结(1)-CSDN博客主要讲述了汽车信息安全应该关注的点,以及相关法规和标准,限于篇幅,继续聊信息安全相关技术以及需要掌握的密码学基础知识。 2.汽…

Linux 第十七章

🐶博主主页:ᰔᩚ. 一怀明月ꦿ ❤️‍🔥专栏系列:线性代数,C初学者入门训练,题解C,C的使用文章,「初学」C,linux 🔥座右铭:“不要等到什么都没有了…

微信小程序常用的api

基础API: wx.request:用于发起网络请求,支持GET、POST等方式,是获取网络数据的主要手段。wx.showToast:显示消息提示框,通常用于向用户展示操作成功、失败或加载中等状态。wx.showModal:显示模态…

打水问题(贪心算法)

题目:有n个人排队到r个水龙头去打水,他们装满水桶的时间t1、t2………tn为整数且各不相等,应如何安排他们的打水顺序才能使他们总共花费的时间最少?通过键盘输入排队打水的人数以及每人打水的时间和水龙头数,使用贪心算…

Spring Data JPA数据批量插入、批量更新真的用对了吗

Spring Data JPA系列 1、SpringBoot集成JPA及基本使用 2、Spring Data JPA Criteria查询、部分字段查询 3、Spring Data JPA数据批量插入、批量更新真的用对了吗 前言 在前两篇文章已经介绍过,在使用Spring Data JPA时,DAO层的Respository通过继承J…