3 tensorflow构建模型详解

news/2025/2/20 21:31:00/

上一篇:2 用TensorFlow构建一个简单的神经网络-CSDN博客

1、神经网络概念

接上一篇,用tensorflow写了一个猜测西瓜价格的简单模型,理解代码前先了解下什么是神经网络。

下面是百度AI对神经网络的解释:

神经网络是一种运算模型,由大量的节点(或称神经元)之间相互联接构成,每个节点代表一种特定的输出函数,称为激励函数(activation function)。每两个节点间的连接都代表一个对于通过该连接信号的加权值,称之为权重,这相当于人工神经网络的记忆。网络的输出则依网络的连接方式,权重值和激励函数的不同而不同。而网络自身通常都是对自然界某种算法或者函数的逼近,也可能是对一种逻辑策略的表达。
神经网络是一种广泛并行互连的网络,它的组织能够模拟生物神经系统对真实世界物体所做出的交互反应。

2、密集层

在上一篇我们创建了预测价格模型,其中代码为:

model = tf.keras.Sequential([
    tf.keras.layers.Dense(1, input_shape=[1])
])

意思是构建了一个模型,里面添加了一层神经网络,只有一个神经元。(Sequential是顺序的意思,Dense是密集层。)

密集层(也叫全连接层),在神经网络中指的是每个神经元都与前一层的所有神经元相连的层。

举个例子,如下图所示:神经元a1与所有输入层数据相连(X1,X2,X3),其他神经元也一样都与上一层神经元相连。

它们之间的数学关系为:

某个神经元是由连接的上一层神经元分别乘上权重(w),再加上偏差(b)得到,例如计算a1:

权重w的数字下标可以按照顺序命名,比如第一个神经元计算的权重可以为w11、w12……,第二个神经元计算的权重可以为w21、w22……

a2、a3计算以此类推。

3、西瓜费用预测模型详解

上一篇西瓜费用计算公式 :费用=1.2元/斤*重量+0.5元

即:y=1.2x+0.5

这是一个一元线性回归问题,只有一个自变量x和一个因变量y,机器学习要推算出权重w=1.2, 偏差b=0.5,才能准确预测费用。

代码如下:

import numpy as np
import tensorflow as tf# 西瓜的重量
weight = np.array([1, 3, 4, 5, 6, 8], dtype=float)# 对应的费用
total_cost = np.array([1.7, 4.1, 5.3, 6.5, 7.7, 10.1], dtype=float)model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=[1])
])model.compile(loss=tf.losses.mean_squared_error, optimizer='SGD')history = model.fit(weight, total_cost, epochs=500)# 训练完成后,预测10斤西瓜的总费用
print(model.predict([10]))

下面是代码详解:

(1)训练数据准备

西瓜重量 weight=[1, 3, 4, 5, 6, 8]

对应的费用 total_cost=[1.7, 4.1, 5.3, 6.5, 7.7, 10.1]

(2)构建模型

model = tf.keras.Sequential([
    tf.keras.layers.Dense(1, input_shape=[1])
])

  • tf.keras.layers.Dense(1, input_shape=[1]),参数1表示1个神经元,我们只要预测费用y,所以输出层只要一个神经元就可以了(注意:神经元不用包含输入层)。
  • input_shape=[1],表示输入数据的形状为单元素列表,即每个输入数据只有一个值。因为只有一个变量x(西瓜的重量),所以此处输入形状是[1]

该模型的示意图:

可以用model.summary()查看模型摘要,代码如下:

import numpy as np
import tensorflow as tf# 西瓜的重量
weight = np.array([1, 3, 4, 5, 6, 8], dtype=float)# 对应的费用
total_cost = np.array([1.7, 4.1, 5.3, 6.5, 7.7, 10.1], dtype=float)model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=[1])
])# 查看模型摘要
model.summary()

运行结果:

可以看到可训练参数有2个,即公式中的w1和b1。

(3)设置损失函数和优化器
model.compile(loss=tf.losses.mean_squared_error, optimizer='SGD')
  • mean_squared_error是均方误差,指的是预测值与真实值差值的平方然后求和再平均。公式为:

                    MSE=1/n Σ(P-G)^2 (P为预测值,G为真实值)

  • SGD即随机梯度下降(Stochastic Gradient Descent),是一种迭代优化算法。

(4)设置训练数据
history = model.fit(weight, total_cost, epochs=500)
  • 设置训练数据的特征和标签,在上述代码中分别是西瓜的重量和费用:weight、total_cost
  • 设置训练轮次epochs=500,1个epochs是指使用所有样本训练一次。

(5) 查看训练结果

看下面的训练过程,第8个epoch的时候损失值loss已经很小了,训练轮次不需要设置到500就可以有很好的预测效果了。

刚开始loss很高,使用优化算法慢慢调整了权重,loss值可以很好地衡量我们的模型有多好。

我们把epoch的值调小,看看程序猜测的权重(w)和偏差(b)是多少,以及loss值的计算。

 

代码改动如下:

  •  epochs=5
  • 用model.get_weights()获取程序猜测的权重数据
import numpy as np
import tensorflow as tf# 西瓜的重量
weight = np.array([1, 3, 4, 5, 6, 8], dtype=float)# 对应的费用
total_cost = np.array([1.7, 4.1, 5.3, 6.5, 7.7, 10.1], dtype=float)model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=[1])
])model.compile(loss=tf.losses.mean_squared_error, optimizer='SGD')history = model.fit(weight, total_cost, epochs=5)# 获取权重数据
w = model.get_weights()[0]
b = model.get_weights()[1]print('w:')
print(w)
print('b: ')
print(b)# 训练完成后,预测10斤西瓜的总费用
print(model.predict([10]))

运行结果:

训练了5个epoch后,程序猜测w是1.1807659,b为0.33192113

            y=wx+b=1.1807659*10+0.33192113=12.139581

所以预测10斤西瓜的总费用是12.139581

                 

4、创建更复杂一点的模型

现实生活中我们要预测的东西影响因素可能有很多个,如房价预测,房价可能受到房屋面积、房间数量等等因素影响。思考一下,下面的神经网络图创建模型时要如何设置参数呢?

model = tf.keras.Sequential([tf.keras.layers.Dense(2, input_shape=[3]),tf.keras.layers.Dense(1)
])


 

         


http://www.ppmy.cn/news/1187141.html

相关文章

【电源专题】什么是POE?

PoE(Power over Ethernet)是指通过网线传输电力的一种技术,借助现有以太网通过网线同时为IP终端设备(如:IP电话、AP、IP摄像头等)进行数据传输和供电。 PoE又被称为基于局域网的供电系统(Power over LAN,简称PoL)或有源以太网( Active Ethernet),有时也被简称为以太网供…

[linux] Syntax error: “(“ unexpected错误,sh报错

参考:Shell 数组 Syntax error ( unexpected_.sh 使用数组报错-CSDN博客 #!/bin/bash arr(a) echo ${arr[0]}sh test.sh执行脚本的时候,报错:Syntax error: "(" unexpected错误。 而使用下面这种方式执行,则不会报错 …

14、SpringCloud -- WebSocket 实时通知用户

目录 实时通知用户需求:代码:前端:后端:WebSocket创建 websocket-server 服务添加依赖:配置 yml 和 启动类:前端:后端代码:注意:测试:总结:实时通知用户 需求: 用户订单秒杀成功之后,对用户进行秒杀成功通知。 弹出个提示框来提示。 代码: 前端:

第五章 I/O管理 六、I/O核心子系统

目录 一、核心子系统 1、I/O调度 2、设备保护 二、假脱机技术 1、脱机: 2、假脱机(SPOOLing技术): 3、应用: 1.独占式设备: 2.共享设备: 4、共享打印机原理分析 三、总结 一、核心子系…

6.scala辅助构造器与为构造函数提供默认值(一)

概述 本文主要说明: 辅助构造器 与 为构造函数提供默认值 的使用 辅助构造器为构造函数提供默认值 相关链接 阅读之前,可以浏览一下 scala相关文章 辅助构造器 可以通过定义名为this的方法来定义辅助Scala类构造函数。只有几个规则需要了解: 每个辅助…

C#中切换产品配方数据

首先,你需要一个数据库来存储配方数据。你可以使用SQL Server、MySQL、SQLite或其他关系型数据库,具体取决于你的需求。 以下是一个基本的步骤: 引用数据库库: 首先,你需要引用适当的数据库库,例如Entity …

轻量封装WebGPU渲染系统示例<8>- 渲染器基本场景管理(源码)

当前示例源码github地址: https://github.com/vilyLei/voxwebgpu/blob/main/src/voxgpu/sample/RSceneTest.ts 此示例渲染系统实现的特性: 1. 用户态与系统态隔离。 2. 高频调用与低频调用隔离。 3. 面向用户的易用性封装。 4. 渲染数据和渲染机制分离。 5. 用户操作和渲…

DHorse改用fabric8的SDK与k8s集群交互

现状 在dhorse 1.4.0版本之前&#xff0c;一直使用k8s官方提供的sdk与k8s集群交互&#xff0c;官方sdk的Maven坐标如下&#xff1a; <dependency><groupId>io.kubernetes</groupId><artifactId>client-java</artifactId><version>18.0.0…