TFLite文件解析及格式转换

news/2025/3/5 5:42:19/

        随着深度学习越来越流行,工业生产不光在PC端应用场景丰富,在移动端也越来越凸显出深度学习的重要性及应用价值。由于嵌入式平台受存储、指令集限制,需要提供更小的网络模型,并且某些DSP平台不支持float指令。tensorflow提供TOCO转换工具能够自动生成量化为U8的TFLite文件。本文将介绍如何解析tflite的网络结构以及权重信息。

一、tflite文件格式

        Tflite文件由tensorflow提供的TOCO工具生成的轻量级模型,存储格式是flatbuffer,它是google开源的一种二进制序列化格式,同功能的像protobuf。对flatbuffer可小结为三点。

1.内容分为vtable区和数据区,vtable区保存着变量的偏移值,数据区保存着变量值;

2.要解析变量a,是在vtable区组合一层层的offset偏移量计算出总偏移,然后以总偏移到数据区中定位从而获取变量a的值。

3.一个叫schema的文本文件定义了要进行序列化和反序列化的数据结构。

具体定义的结构可以参考tensorflow源码中的schema文件:https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/schema/schema.fbs

二、tflite解析

         由于工作需要,本文使用了google flatbuffer开源工具flatc,flatc可以实现tflite格式到jason文件的自动转换。

flatbuffer源码:https://github.com/google/flatbuffers 

安装命令: cmake -G "Unix Makefiles" //生成MakeFile

                   make //生成flatc

                   make install //安裝flatc

安装完成后,从tensorflow源码中copy 结构文件schema.fbs到flatbuffer根目录,执行#./flatc -t schema.fbs -- mobilenet_v1_1.0_224_quant.tflite,生成对应的json文件。Json文件结构如下图所示:

operator_codes: 以列表的形式存储该网络结构用的layer种类;

subgraphs: 为每一层的具体信息具体包括:

            1)tensors.包含input、weight、bias的shape信息、量化参数以及在buffer数据区的offset值;

            2)inputs: 整个网络的输入对应的tensors索引;

            3)outputs: 整个网络的输出对应的tensors索引;

            4)operators:网络结构所需要的相关参数;

buffers: 存放weight、bias等权重信息。

三、网络结构及权重提取

       使用python的json包可以很方便的读取tflite生成的json文件。关于解析过程有几点说明:

        1.flatc转换的json文件不是标准的key-value格式,需要稍作转换给索引key加上双引号具体代码如下:

# -*- coding: UTF-8 -*-
import os
pathIn='xxx.json'
pathDst='xxx_new.json'
f = open(pathIn)             # 返回一个文件对象
line = f.readline()             # 调用文件的 readline()方法
fout = open(pathDst,'w') while line:#print(line)  #print(len(line))dstline='aaa'if line.find(':')!=-1:quoteIdx2=line.find(':')#print("line has :, and index =%d" %quoteIdx2)linenew=line[:quoteIdx2] + '"' + line[quoteIdx2:]quoteIdx1=linenew.rfind(' ',0, quoteIdx2)#print("quoteIdx1 %d" %quoteIdx1)dstline=linenew[:quoteIdx1+1] + '"' + linenew[quoteIdx1+1:]#print(dstline)fout.write(dstline+os.linesep) else:dstline=linefout.write(line) #print("No")#print dstlineline = f.readline()f.close()
fout.close()

        2.由于量化后的bias为int32的类型,而flatc将bias数据按照uint8的格式进行了转换,这里需要对json文件的bias再转换回int32类型,相当于json中bias区域四个字节转换为一个int32。详细讨论参考tensorflow github链接:https://github.com/tensorflow/tensorflow/issues/22279

        解析部分代码分为两个部分包括网络结构以及权重解析,方法相似。网络结构参数解析,部分代码如下:

from __future__ import division
import jsondef write_blob_info(p_file, inputs, input_shape):p_file.write(str(inputs) + ', ')p_file.write(str(3) + ', ')p_file.write(str(input_shape[3]) + ', ')p_file.write(str(input_shape[1]) + ', ')p_file.write(str(input_shape[2]) + ', ')with open("mobilenet_v1_1.0_224_quant.json",'r') as f:load_dict = json.load(f)param_file=open("mobilenet_v1_1.0_224_quant.proto",'w')tensors = load_dict["subgraphs"][0]["tensors"]
operators = load_dict["subgraphs"][0]["operators"]
inputs = load_dict["subgraphs"][0]["inputs"]
input_shape = tensors[inputs[0]]["shape"]param_file.write(str(len(operators) + 1) + ',\n')
write_blob_info(param_file, \inputs[0], \input_shape)
param_file.write('\n')for layer in operators:layer_name = layer["builtin_options_type"]operators_inputs  = layer["inputs"]input_len = len(operators_inputs)builtin_options = layer["builtin_options"]if layer_name == "Conv2DOptions":	#conv_2d, depthwiseconv_2dinput_shape  = tensors[operators_inputs[0]]["shape"]kernel_shape = tensors[operators_inputs[1]]["shape"]bias_shape   = tensors[operators_inputs[2]]["shape"]kernel_H = kernel_shape[1]kernel_W = kernel_shape[2]param_file.write(str(kernel_H) + ', ')param_file.write(str(kernel_W) + ', ')stride_H = builtin_options["stride_h"]stride_W = builtin_options["stride_w"]param_file.write(str(stride_H) + ', ')param_file.write(str(stride_W) + ', ')		dilation_W = builtin_options["dilation_w_factor"]dilation_H = builtin_options["dilation_h_factor"]param_file.write(str(dilation_H) + ', ')param_file.write(str(dilation_W) + ', ')bias_term = 1if input_len  < 3 or bias_shape[0] == 0:bias_term = 0param_file.write(str(bias_term) + ', ')bottom_zero_point = tensors[operators_inputs[0]]["quantization"]["zero_point"][0]param_file.write(str(bottom_zero_point) + ', ')write_blob_info(param_file, \operators_inputs[0], \input_shape)	#output_bloboperators_outputs = layer["outputs"]output_shape = tensors[operators_outputs[0]]["shape"]write_blob_info(param_file, \operators_outputs[0], \output_shape)param_file.write('\n')

(水平有限,如有问题及遗漏欢迎补充指出,互相学习。)

 


http://www.ppmy.cn/news/244103.html

相关文章

iOS 高级工程师面试必备

请简要介绍一下 MVC、MVVM 和 VIPER 架构模式。它们的优缺点分别是什么&#xff1f; MVC&#xff08;Model-View-Controller&#xff09;&#xff1a; MVC 是一种经典的软件架构模式&#xff0c;主要分为三个部分&#xff1a;Model&#xff08;模型&#xff09;、View&#xf…

CS1237 数据实测

目录 硬件 测试数据 结论 性能 转换时间 硬件 外部电阻电位器 模拟 传感器信号 测试数据 一下记录数据为ADC原始数据未经过软件滤波 跳动幅度图片DEV_FREQUENCY_10DEV_PGA_1751DEV_FREQUENCY_40DEV_PGA_11562DEV_FREQUENCY_640DEV_PGA_18243DEV_FREQUENCY_640DEV_PGA_644…

Linux内核sync流程

进程写文件时&#xff0c; 文件并没有真正写到存储设备&#xff0c; 而是写到了page cache中。 文件系统会定期把脏页写到存储设备&#xff0c; 进程也可以调用sync 这样的调用把脏页写回存储设设备。 数据结构 backing_dev_info 要理解这个结构体&#xff0c; 得从它需要解…

牛客网Linux错题六

1.有一个文件ip.txt&#xff0c;每行一条ip 记录&#xff0c;共若干行&#xff0c;已排好序&#xff0c;下面哪个命令可以实现“统计出现次数最多的前3个ip及其次数”&#xff1f;&#xff08;B&#xff09; A. uniq -c ip.txt B. uniq -c ip.txt | sort -nr | head -n 3 C.…

SpringBoot个人博客系统(含源码+数据库)

一、作品设计理念 个人博客系统是一个让个人可以通过互联网自由表达、交流和分享的平台&#xff0c;是个人展示自己思想、感受和经验的品牌。设计理念对于任何一个个人博客系统来说都非常重要&#xff0c;它直接影响到用户的使用体验和网站的整体感觉。 好的设计理念应该着眼于…

预见未来:超强元AI诞生,抓住这个机会,利用AI变现也变得更加容易

目录 一、引言 二、介绍 三、技术展现 四、元AI架构图展现 五、元AI变现技巧—商业版说明 六、后期规划 一、引言 如何利用AI变现已经成为了当今各个行业亟需解决的问题。随着人工智能技术的快速发展和普及&#xff0c;越来越多的企业开始将其应用于产品研发、销售流程优化、客…

SpringData整合ElasticSearch

一、环境搭建 1.Maven依赖 <parent><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-parent</artifactId><version>2.3.6.RELEASE</version><relativePath/></parent><dependencies&g…

小米9es更新MIUI 11.0.3.0稳定版本,解决耗电问题

等待以已久的小米MIUI11系统终于迎来更新。今天早上有不少米粉称已经更新小米MIUI11.0.3.0稳定版本&#xff0c;本次更新大小为700M。 更新主要内容&#xff1a; 全面成熟和完善的全面屏设计&#xff0c;去除多余的视觉符号&#xff0c;采用精心设计大屏触控互交控件&#xff…