【3】模型相关函数及构建二维线性模型

news/2024/11/24 6:34:41/

1 保存和恢复模型

1.1保存模型

tf.train.Saver()函数可以建立一个saver对象,然后在session中调用save即可将模型保存起来。

 

# 导入tensorflow类库
import tensorflow as tfv1 = tf.Variable(tf.constant([[5.0, 6.0], [7.0, 7.0]], shape=[2, 2]), name="m1")
v2 = tf.Variable(tf.constant([[4.0, 6.0], [7.0, 8.0]], shape=[2, 2]), name="m2")
result = v1 + v2
saver = tf.train.Saver()
init_op = tf.global_variables_initializer()
with tf.Session() as sess:sess.run(init_op)saver.restore(sess, "model/model.ckpt")print(sess.run(result))

会产生四个文件:

1.2 载入模型

saver.restore()

 通过tf.train.import_meta_graph直接加载计算图,获得模型的输出结果。

import tensorflow as tf# 通过tf.train.import_meta_graph,直接加载持久化的图
saver = tf.train.import_meta_graph("model/model.ckpt.meta")
with tf.Session()as sess:# saver.restore在当前会话中还原模型saver.restore(sess, 'model/model.ckpt')print("m1", sess.run(tf.get_default_graph().get_tensor_by_name('m1:0')))print("m2", sess.run(tf.get_default_graph().get_tensor_by_name('m2:0')))print(sess.run(tf.get_default_graph().get_tensor_by_name('add:0')))

2 使用模型预测

使用模型预测步骤:

import tensorflow.compat.v1 as tf# 加载计算图,不加载参数
saver = tf.train.import_meta_graph('predict/predict_model.ckpt.meta')
with tf.Session()as sess:# 加载x节点input_x = sess.graph.get_tensor_by_name('x:0')# 加载y节点input_y = sess.graph.get_tensor_by_name('y:0')# 获得矩阵相乘操作mul_result = sess.graph.get_tensor_by_name('mul_result:0')# 向节点喂入数据,获得输出结果result = sess.run(mul_result, feed_dict={input_x: [[2, 3, 4], [2, 3, 4]], input_y: [[1, 2, 3], [3, 5, 5]]})print("矩阵乘法结果:", result)

3 构建二维线性拟合模型

步骤:

 3.1准备数据

随机产生数据:

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt#准备数据
train_x=np.linspace(-1,1,100)
train_y=2*train_x + np.random.randn(*train_x.shape)*0.5
plt.plot(train_x,train_y,'ro',label="原始数据集")
plt.show()

 3.2 搭建模型

###搭建模型
x=tf.placeholder(dtype=tf.float32)
y=tf.placeholder(dtype=tf.float32)
w=tf.Variable(tf.random_normal([1]),name='weight')
b=tf.Variable(tf.zeros([1]),name='bias')
z=tf.multiply(x,w)+b

 3.3 反向传播

 

###反向传播
cost=tf.reduce_mean(tf.square(y-z))#均方误差
learning_rate=0.05
optimizer=tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)#优化器采用随机梯度下降算法

3.4 迭代训练模型

###迭代训练模型
training_epochs=100
display_step=10
saver=tf.train.Saver()
with tf.Session() as sess:sess.run(tf.global_variables_initializer())for epoch in range(training_epochs):  #向模型中输入数据for (x_data,y_data) in zip(train_x,train_y):sess.run(optimizer,feed_dict={x:x_data,y:y_data})if epoch % display_step ==0:loss=sess.run(cost,feed_dict={x:x_data,y:y_data})print('Epoch:',epoch+1,'cost:',loss)saver.save(sess,save_path="linear/linear.ckpt")#保存模型

 

3.5 模型预测

###模型预测
with tf.Session() as sess:saver.restore(sess,"linear/linear.ckpt")print("模型的预测值为:",sess.run(z,feed_dict={x:0.6}))
模型的预测值为: [1.2711029]


http://www.ppmy.cn/news/78250.html

相关文章

CSDN MD编辑器跳转方法及字体格式

一、点击关键语句跳转指定位置 在CSDN写文章的时候,写的文章过长往往会让读者很难找到自己想看的部分,这时候有个 跳转到指定位置功能 就非常的便利。CSDN在MD编辑器上(富文本编辑器只有一种)就提供了两种跳转到指定位置的方法: 一、目录跳转…

使用 Kafka Assistant,为您的开发加速

简要介绍 快速查看所有 Kafka 集群,包括Brokers、Topics和Consumers支持各种认证模式:PLAINTEXT、SASL_PLAINTEXT、SSL、SASL_SSL对Kafka集群进行健康检查查看分区中的消息内容并添加新消息查看消费者订阅了哪些主题,以及分区被分配给了哪些…

如何实现自我管理?

作者 | Stefan Wolpers 自我管理是一个组织实现业务敏捷性的重要组成部分,还是一种很好的文化转变,例如,让团队快乐并吸引新人才。 虽然很多人,尤其是管理层的人,对这一概念持怀疑态度,但我相信&#xff…

月薪10k和月薪25k的软件测试人员有什么区别?看完你就不会再迷茫了

了解软件测试这行的人都清楚,功能测试的天花板可能也就15k左右,而自动化的起点就在15k左右,当然两个岗位需要掌握的技能肯定是不一样的。 如果刚入门学习完软件测试,那么基本薪资会在7-8k左右,这个薪资不太高主要是因…

汇编十一、汇编实现外部中断

1、实现目的 (1)实现8颗LED灯呈流水灯依次被点亮;静态数码管通过按键按下,显示数值发生改变,通过按键依次显示0-9。 (2)按键检测采用外部中断检测。 2、原理图及硬件连接 2.1、LED灯 (1)51单片机P1端口接八个共阴极LED灯,即I…

面试经验小结

1、为什么C有重载而C语言没有? C的编译过程中,将函数名后面的数据类型也加入到了编译阶段。 2、用异或完成两个数的数值交换。 x^y; y^x; x^y; 3、数组指针与指针数组;函数指针与指针函数 4、segment …

一体化医学影像平台PACS源码,影像存档与传输系统源码

PACS影像存档与传输系统源码 PACS即影像存档与传输系统,是医学影像、数字化图像技术、计算机技术和网络通讯技术相结合的产物,是处理各种医学影像信息的采集、存储、报告、输出、管理、查询的计算机应用程序。 是基于DICOM标准的医学影像管理系统&…

研发项目工时统计工具哪个好?9大工时管理系统盘点

工时管理是项目型企业的重要需求,特别是在人力成本占比较高的行业,如软件开发、设计咨询、会计律师等。工时管理可以帮助企业核算项目人工成本,控制成本投入,提高项目利润,客观考核员工绩效,优化资源分配等…