【tensorflow】TF1.x保存.pb模型 解决模型越训练越大问题

news/2025/2/12 7:47:06/

【tensorflow】TF1.x保存.pb模型 解决模型越训练越大问题

  • 举例:模型定义如下
  • 模型保存代码

  在上一篇博客【tensorflow】TF1.x保存与读取.pb模型写法介绍介绍的保存.pb模型方法中,保存的是模型训练过程中所有的参数,而且训练越久,最终保存的模型就越大。我的模型只有几千参数,可是最终保存的文件有1GB。。。。

  但是其实我只想要保存参数去部署模型,然后预测。网上有一些解决方案但都不是我需要的,因为我要用Java部署模型,python这里必须要用builder.add_meta_graph_and_variables来保存参数。以下是解决方案:

举例:模型定义如下

# 定义模型
with tf.name_scope("Model"):"""MLP"""# 13个连续特征数据(13列)x = tf.placeholder(tf.float32, [None,13], name='X') # 正则化x_norm = tf.layers.batch_normalization(inputs=x)# 定义一层Densedense_1 = tf.layers.Dense(64, activation="relu")(x_norm)"""EMBED"""# 离散输入y = tf.placeholder(tf.int32, [None,2], name='Y')# 创建嵌入矩阵变量embedding_matrix = tf.Variable(tf.random_uniform([len(vocab_dict) + 1, 8], -1.0, 1.0))# 使用tf.nn.embedding_lookup函数获取嵌入向量embeddings = tf.nn.embedding_lookup(embedding_matrix, y)# 创建 LSTM 层lstm_cell = tf.nn.rnn_cell.LSTMCell(64)# 初始化 LSTM 单元状态initial_state = lstm_cell.zero_state(tf.shape(embeddings)[0], tf.float32)# 将输入数据传递给 LSTM 层lstm_out, _ = tf.nn.dynamic_rnn(lstm_cell, embeddings, initial_state=initial_state)# 定义一层Densedense_2 = tf.layers.Dense(64, activation="relu")(lstm_out[:, -1, :])"""MERGE"""combined = tf.concat([dense_1, dense_2], axis = -1)pred = tf.layers.Dense(2, activation="relu")(combined)pred = tf.layers.Dense(1, activation="linear", name='P')(pred)z = tf.placeholder(tf.float32, [None, 1], name='Z')

  虽然写这么多,但是上面模型的输入只有xyz,输出只有pred。所以我们保存、加载模型时,只用考虑这几个变量就可以。

模型保存代码

  这里的保存方法建议对比上一篇博客【tensorflow】TF1.x保存与读取.pb模型写法介绍介绍的保存.pb模型方法来看。

import tensorflow as tf
from tensorflow import saved_model as sm
from tensorflow.tools.graph_transforms import TransformGraph
from tensorflow.core.framework import graph_pb2def get_node_names(name_list, nodes_list):name_list.extend([n.name.split(":")[0] for _, n in nodes_list.items() if n.name.split(":")[0] != ''])# 创建 Saver 对象
saver = tf.train.Saver()# 生成会话,训练STEPS轮
with tf.Session() as sess:# 初始化参数sess.run(tf.global_variables_initializer())...... # 模型训练逻辑# 准备存储模型path = 'pb_model/'# 创建 Saver 对象,用于保存和加载模型的变量pb_saver = tf.train.Saver(var_list=None)# 将 Saver 对象转换为 SaverDef 对象saver_def = pb_saver.as_saver_def()# 从会话的图定义中提取包含恢复操作的子图saver_def_ingraph = tf.graph_util.extract_sub_graph(sess.graph.as_graph_def(), [saver_def.restore_op_name])# 构建需要在新会话中恢复的变量的 TensorInfo protobuf# 自定义 根据自己的模型来写inputs = {'x' : sm.utils.build_tensor_info(x),'y' : sm.utils.build_tensor_info(y),'z' : sm.utils.build_tensor_info(z)}outputs = {'p' : sm.utils.build_tensor_info(pred)}# 获取节点的名称input_node_names = []get_node_names(input_node_names, inputs)output_node_names = []get_node_names(output_node_names, outputs)# 获取当前会话的图定义input_graph_def = sess.graph.as_graph_def()# 定义需要应用的图转换操作列表transforms = ['add_default_attributes','fold_constants(ignore_errors=true)','fold_batch_norms','fold_old_batch_norms','sort_by_execution_order','strip_unused_nodes']# 应用图转换操作,并获取优化后的图定义opt_graph_def = TransformGraph(input_graph_def,input_node_names,output_node_names,transforms)# 创建新的默认图并导入优化后的图定义with tf.Graph().as_default() as graph:all_names = set([node.name for node in opt_graph_def.node])saver_def_ingraph_nodes = [node for node in saver_def_ingraph.node if not node.name in all_names]merged_graph_def = graph_pb2.GraphDef()merged_graph_def.node.extend(opt_graph_def.node)merged_graph_def.node.extend(saver_def_ingraph_nodes)# 导入合并后的图定义到新的默认图中tf.graph_util.import_graph_def(merged_graph_def, name="")builder = sm.builder.SavedModelBuilder(path)# 将 graph 和变量等信息写入 MetaGraphDef protobuf# 这里的 tags 里面的参数和 signature_def_map 字典里面的键都可以是自定义字符串,也可用tf里预设好的方便统一builder.add_meta_graph_and_variables(sess, tags=[sm.tag_constants.SERVING],signature_def_map={sm.signature_constants.PREDICT_METHOD_NAME: SignatureDef},saver=pb_saver,main_op=tf.local_variables_initializer())# 将 MetaGraphDef 写入磁盘builder.save()

  这样之后你会发现模型的大小从GB锐减到几十KB。


http://www.ppmy.cn/news/736443.html

相关文章

FPGA DVB-S2 FEC 信道编码 BCH编码器 LDPC编码器 交织器 IP core

基于FPGA的DVB-S2发射机IP core,含BCH编码IP、LDPC编码IP、交织IP。 (1)支持DVB-S2标准中BCH码全部编码样式; 长帧(64800),Nbch:16200、21600 、25920、32400、38880、43200、48600、51840、54000、57600、 58320; 短帧(16200),…

【数据分析案例】游戏付费用户RFM分析案例

前言: 该案例由随机生成数据模拟半年时间内,对游戏用户充值金额情况进行用户价值分层 数据特征: 所有数据均为随机生成数据,一共有三个数据特征,分别是 r_datatime:充值时间,时间截格式&…

17. Redis sentinel机制-实现高可用

Redis 的Sentinel机制是Redis官方提供的保证Redis高可用的工具,Redis Sentinel 采用Raft 分布式一致性算法来保证Redis 的高可用. 1. Sentinel 机制 1.1 Sentinel 主要功能 Monitoring(监控): Sentinel 会不断监测主服务器和从服务器是否正常运行Notification(通…

提前做好网络安全分析,运维真轻松(二)

背景 某汽车总部已部署NetInside流量分析系统,使用流量分析系统提供实时和历史原始流量。汽车配件电子图册系统是某汽车集团的重要业务系统。本次分析重点针对汽车配件电子图册系统进行预见性分析,以供安全取证、性能分析、网络质量监测以及深层网络分析…

springBoot整合redis启动报错:event executor terminated

springBoot整合redis启动报错:java.util.concurrent.RejectedExecutionException: event executor terminated 背景 redis一主两从三哨兵部署模式搭建完成后,需要整合springCloud项目,替换掉之前的redis单机模式,更改nacos配置中…

今天打开个税APP,我直接人麻了!

点击上方“码农突围”,马上关注 这里是码农充电第一站,回复“666”,获取一份专属大礼包 真爱,请设置“星标”或点个“在看 这是【码农突围】的第 432 篇原创分享 作者 l 突围的鱼 来源 l 码农突围(ID:smart…

05. Redis 环境搭建-高可用集群(HA)

在生产环境中,Redis 架构使用最多的就是Sentinel主从架构, 因为单点容易产生故障, 分片集群又过于复杂. 笔者尝试在一台服务器上搭建一个一主两从, 三个哨兵监听的Redis 集群架构。由于哨兵也可能发生单点故障,所以笔者也使用了三…

Elasticsearch(六)--ES文档的操作(中)---修改文档

一、前言 上篇文章我们了解了ES的插入和批量插入文档的操作,分别通过ES的kibana客户端以及Java高级Rest客户端进行学习,那么本篇则进入到对文档的修改操作,同新增文档,也有更新单条文档和批量更新文档操作,但还多出一…