TensorFlow 示例摄氏度到华氏度的转换(二)

embedded/2025/2/3 19:08:50/

这是一个完整的神经网络实现,用于将摄氏度转换为华氏度。下面,我会逐步描述各个步骤,并提供完整代码。

1. 数据准备与预处理

在这部分,我们准备了摄氏度(features)与对应的华氏度(labels)数据。这些数据将作为输入和输出提供给神经网络。我们还需要将输入数据的形状调整为二维数组 (N, 1),因为 TensorFlow 要求输入数据的形状为二维。
这些数据表示摄氏度到华氏度的转换公式:华氏度 = 摄氏度 × 9/5 + 32

features = np.array([-50, -40, -10, 0, 8, 22, 35, 45, 55, 65, 75, 95], dtype=float)
labels = np.array([-58.0, -40.0, 14.0, 32.0, 46.4, 71.6, 95.0, 113.0, 131.0, 149.0, 167.0, 203.0], dtype=float)

2. 构建模型

我们使用 tf.keras.Sequential 创建一个简单的神经网络模型。该模型包含一个层,即 Dense 层,它表示一个全连接层。这个层的输入是一个数值(摄氏度),输出一个数值(华氏度)。

layer = tf.keras.layers.Dense(units=1, input_shape=[1])
model = tf.keras.Sequential([layer])

3. 编译模型

在编译模型时,我们指定优化器为 Adam,并设置学习率为 0.1。损失函数使用 mean_squared_error,因为我们是进行回归任务。

model.compile(optimizer=tf.keras.optimizers.Adam(0.1), loss='mean_squared_error')

4. 训练模型

我们使用 fit() 方法进行训练。设定了 epochs=1000,即训练模型 1000 次。

history = model.fit(features, labels, epochs=1000, verbose=1)

5. 评估模型

训练完成后,我们可以用 model.predict() 方法来做预测。输入一个摄氏度(例如 38.1°C),模型会输出预测的华氏度。

print(model.predict(np.array([[38.1]])))  # 例如:预测38.1°C对应的华氏度

6. 模型应用与预测

我们通过调用 model.predict(np.array([[38.1]])) 来预测给定输入摄氏度对应的华氏度。

prediction = model.predict(np.array([[10]]))
print(f"10°C 对应的华氏度是: {prediction[0][0]}")

7. 保存与加载模型

可以使用 model.save('model.h5') 来保存模型,使用 tf.keras.models.load_model('model.h5') 来加载保存的模型。

model.save('temperature_model.h5')  # 保存模型
# 加载模型
loaded_model = tf.keras.models.load_model('temperature_model.h5')

8. 可视化损失变化

训练过程中,损失会随着训练轮数的增加而变化。我们通过 history.history['loss'] 获取训练过程中的损失变化,并用 matplotlib 可视化出来。

plt.plot(history.history['loss'])
plt.title("Training Loss")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.show()

9. 完整代码

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt# 1. 数据准备与预处理
# 温度数据:摄氏度到华氏度的转换
features = np.array([-50, -40, -10, 0, 8, 22, 35, 45, 55, 65, 75, 95], dtype=float)
labels = np.array([-58.0, -40.0, 14.0, 32.0, 46.4, 71.6, 95.0, 113.0, 131.0, 149.0, 167.0, 203.0], dtype=float)# 调整输入数据形状为二维数组 (N, 1)
features = features.reshape(-1, 1)# 2. 构建模型
layer = tf.keras.layers.Dense(units=1, input_shape=[1])  # 输入一个值,输出一个值
model = tf.keras.Sequential([layer])# 3. 编译模型
model.compile(optimizer=tf.keras.optimizers.Adam(0.1), loss='mean_squared_error')# 4. 训练模型
history = model.fit(features, labels, epochs=1000, verbose=1)# 5. 评估模型
# 预测新数据
print(model.predict(np.array([[38.1]])))  # 例如:预测38.1°C对应的华氏度
print(layer.get_weights())  # 查看训练后的模型权重# 6. 模型应用与预测
# 例如:用模型预测其他摄氏度的华氏度值
# prediction = model.predict(np.array([[10]]))
# print(f"10°C 对应的华氏度是: {prediction[0][0]}")# 7. 保存与加载模型
# 保存模型
# model.save('temperature_model.h5')
# 加载模型
# loaded_model = tf.keras.models.load_model('temperature_model.h5')# 8. 可视化损失变化
plt.plot(history.history['loss'])
plt.title("Training Loss")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.show()

10. 总结

  1. 准备数据:摄氏度和对应的华氏度数据。
  2. 构建模型:使用 Keras 创建简单的神经网络模型。
  3. 编译模型:选择优化器、损失函数并设置学习率。
  4. 训练模型:用数据训练模型,让模型学习摄氏度到华氏度的关系。
  5. 评估与预测:用训练好的模型预测新的摄氏度对应的华氏度。
  6. 保存和加载模型:保存训练好的模型以便以后使用。
  7. 可视化损失变化:观察训练过程中的损失值,帮助评估模型效果。

每个步骤都有其目的和作用,希望通过这个逐步讲解,能帮助你更清楚地理解如何用神经网络进行温度转换任务!


11. 视频链接

 https://v.douyin.com/ifnXmRHG/ 复制此链接,打开Dou音搜索,直接观看视频!


http://www.ppmy.cn/embedded/159246.html

相关文章

32. C 语言 安全函数( _s 尾缀)

本章目录 前言什么是安全函数?安全函数的特点主要的安全函数1. 字符串操作安全函数2. 格式化输出安全函数3. 内存操作安全函数4. 其他常用安全函数 安全函数实例示例 1:strcpy_s 和 strcat_s示例 2:memcpy_s示例 3:strtok_s 总结 …

【论文阅读笔记】“万字”关于深度学习的图像和视频阴影检测、去除和生成的综述笔记 | 2024.9.3

论文“Unveiling Deep Shadows: A Survey on Image and Video Shadow Detection, Removal, and Generation in the Era of Deep Learning”内容包含第1节简介、第2-5节分别对阴影检测、实例阴影检测、阴影去除和阴影生成进行了全面的综述。第6节深入讨论了阴影分析&#xff0…

每日 Java 面试题分享【第 16 天】

欢迎来到每日 Java 面试题分享栏目! 订阅专栏,不错过每一天的练习 今日分享 3 道面试题目! 评论区复述一遍印象更深刻噢~ 目录 问题一:Java 运行时异常和编译时异常之间的区别是什么?问题二:什么是 Jav…

2181、合并零之间的节点

2181、[中等] 合并零之间的节点 1、问题描述: 给你一个链表的头节点 head ,该链表包含由 0 分隔开的一连串整数。链表的 开端 和 末尾 的节点都满足 Node.val 0 。 对于每两个相邻的 0 ,请你将它们之间的所有节点合并成一个节点&#xff…

minimind - 从零开始训练小型语言模型

大语言模型(LLM)领域,如 GPT、LLaMA、GLM 等,虽然它们效果惊艳, 但动辄10 Bilion庞大的模型参数个人设备显存远不够训练,甚至推理困难。 几乎所有人都不会只满足于用Lora等方案fine-tuing大模型学会一些新的…

【开源免费】基于SpringBoot+Vue.JS体育馆管理系统(JAVA毕业设计)

本文项目编号 T 165 ,文末自助获取源码 \color{red}{T165,文末自助获取源码} T165,文末自助获取源码 目录 一、系统介绍二、数据库设计三、配套教程3.1 启动教程3.2 讲解视频3.3 二次开发教程 四、功能截图五、文案资料5.1 选题背景5.2 国内…

LeetCode435周赛T2贪心

题目描述 给你一个由字符 N、S、E 和 W 组成的字符串 s,其中 s[i] 表示在无限网格中的移动操作: N:向北移动 1 个单位。S:向南移动 1 个单位。E:向东移动 1 个单位。W:向西移动 1 个单位。 初始时&#…

ubuntu 下使用deepseek

安装Ollama sudo snap install ollama 执行 ollama run deepseek-coder 然后进行等待。。。