【机器学习:十五、神经网络的编译和训练】

ops/2025/1/17 20:22:54/

1. TensorFlow实现代码

TensorFlow 是深度学习中最为广泛使用的框架之一,提供了灵活的接口来构建、编译和训练神经网络。以下是实现神经网络的一个完整代码示例,以“手写数字识别”为例:

import tensorflow as tf
from tensorflow.keras import layers, models# 加载 MNIST 数据集
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0# 构建模型
model = models.Sequential([layers.Flatten(input_shape=(28, 28)),layers.Dense(128, activation='relu'),layers.Dense(10, activation='softmax')
])# 编译模型
model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['accuracy'])# 训练模型
model.fit(x_train, y_train, epochs=5)# 测试模型
test_loss, test_acc = model.evaluate(x_test, y_test)
print(f"测试准确率: {test_acc}")

以上代码展示了从加载数据到模型训练和测试的完整流程,后续小节将分解具体步骤进行详解。


2. 编译 compile()

编译模型的重要性
model.compile()神经网络模型在 TensorFlow 中的关键步骤,用于指定优化器、损失函数和评估指标。编译后,模型才能够进行训练。其功能包括:

  • 定义优化器:决定模型如何更新权重(如 Adam、SGD)。
  • 设置损失函数:衡量预测值与真实值之间的误差。
  • 选择评估指标:训练过程中实时监控模型性能。

常用参数解释

model.compile(optimizer='adam',  # 指定优化器loss='sparse_categorical_crossentropy',  # 损失函数metrics=['accuracy'])  # 评估指标
  • optimizer:优化器可选用 SGD、RMSprop、Adam 等。Adam 适合大多数任务。
  • loss:根据任务选择合适的损失函数。例如分类任务用交叉熵,回归任务用均方误差。
  • metrics:常用指标包括准确率(accuracy)和均方误差(mse)。

3. 训练 fit()

fit() 是 TensorFlow 模型训练的核心方法,用于指定训练数据、批量大小、训练轮数等。

model.fit(x_train, y_train, batch_size=32, epochs=10, validation_split=0.2)

参数解释

  • x_trainy_train:训练数据及其对应标签。
  • batch_size:每次训练使用的数据样本数。较小的批量会增加训练时间,但收敛更稳定。
  • epochs:完整训练数据通过神经网络的次数。
  • validation_split:从训练数据中划分一定比例用于验证模型性能。

训练结果分析
fit() 会输出训练过程的损失值和评估指标(如准确率)。通过观察这些值的变化,可以判断模型是否过拟合或欠拟合。


4. 模型结构及代码

神经网络的结构设计直接影响模型性能。以下是经典网络的常见设计:

  • 输入层:用于接受数据。
  • 隐藏层:包含多个神经元,负责提取特征。
  • 输出层:根据任务设置输出类别或数值。

以 MNIST 分类为例

model = models.Sequential([layers.Flatten(input_shape=(28, 28)),  # 输入层layers.Dense(128, activation='relu'),  # 隐藏层layers.Dense(10, activation='softmax')  # 输出层
])

5. 算法步骤

训练神经网络的基本步骤如下:

  1. 初始化模型和参数。
  2. 数据预处理:归一化、数据增强等。
  3. 构建模型:选择适当的层数、神经元数和激活函数。
  4. 编译模型:定义损失函数和优化器。
  5. 模型训练:使用训练数据进行多轮迭代。
  6. 测试模型:用测试数据评估最终性能。

6. 损失函数和优化函数的数学公式

  • 损失函数:衡量预测值与真实值之间的差距。

    • 分类任务:CrossEntropy = -Σ(y_true * log(y_pred))
    • 回归任务:MSE = (1/n)Σ(y_true - y_pred)^2
  • 优化函数:通过梯度下降最小化损失函数。

    • 梯度下降公式:w_new = w_old - learning_rate * ∂L/∂w

7. 二元交叉熵损失函数:适用于二分类问题

对于二分类任务(如垃圾邮件检测),交叉熵损失函数是最常用的选择:

  • 数学公式:
    BinaryCrossEntropy = -[y * log(p) + (1-y) * log(1-p)]

  • TensorFlow 实现:

    loss = tf.keras.losses.BinaryCrossentropy()
    

8. 均方误差损失函数:适用于回归问题

均方误差(MSE)适用于预测连续数值:

  • 数学公式:
    MSE = (1/n)Σ(y_true - y_pred)^2

  • TensorFlow 实现:

    loss = tf.keras.losses.MeanSquaredError()
    

9. 总结

神经网络的编译和训练是深度学习的核心环节。通过选择合适的损失函数和优化器,结合数据的有效预处理,能够实现高效的模型训练与预测。TensorFlow 提供了丰富的接口和工具,使得开发者可以快速构建和调试神经网络应用。


http://www.ppmy.cn/ops/150906.html

相关文章

util层注入service

简介背景 在 Java 或 Spring 框架中,util 层通常用于存放工具类或辅助类,而 service 层则通常包含核心业务逻辑。在一些情况下,可能需要将 service 层注入到 util 层中,以便在工具类中调用某些业务逻辑。虽然这种做法并不是最常见…

Spring Boot 下的Swagger 3.0 与 Swagger 2.0 的详细对比

先说结论: Swgger 3.0 与Swagger 2.0 区别很大,Swagger3.0用了最新的注释实现更强大的功能,同时使得代码更优雅。 就个人而言,如果新项目推荐使用Swgger 3.0,对于工具而言新的一定比旧的好;对接于旧项目原…

梁山派入门指南2——滴答定时器位带操作按键输入(包括GPIO中断)

梁山派入门指南2——滴答定时器&位带操作&按键输入 1. 滴答定时器1.1 滴答定时器简介1.2 相关寄存器1.3 固件库函数 2. 位带操作2.1 位带操作介绍2.2 位带操作的优势2.3 支持位带操作的内存地址2.4 位带别名区地址的计算方式2.5 位带操作使用示例 3 按键输入3.1 独立按…

目标检测新视野 | YOLO、SSD与Faster R-CNN三大目标检测模型深度对比分析

目录 引言 YOLO系列 网络结构 多尺度检测 损失函数 关键特性 SSD 锚框设计 损失函数 关键特性 Faster R-CNN 区域建议网络(RPN) 两阶段检测器 损失函数 差异分析 共同特点 基于深度学习 目标框预测 损失函数优化 支持多类别检测 应…

idea上git log面板的使用

文章目录 各种颜色含义具体的文件的颜色标签颜色🏷️ 节点和路线 各种颜色含义 具体的文件的颜色 红色:表示还没有 git add 提交到暂存区绿色:表示已经 git add 过,但是从来没有 commit 过蓝色:表示文件有过改动 标…

C#轻松实现条形码二维码生成及识别

一、前言 大家好!我是付工。 今天给大家分享一下,如何基于C#来生成并识别条形码或者二维码。 二、ZXing.Net 实现二维码生成的库有很多,我们这里采用的是http://ZXing.Net。 ZXing是一个开放源码的,用Java实现的多种格式的一…

pytorch小记(六):pytorch中的clone和detach操作:克隆/复制数据 vs 共享相同数据但 与计算图断开联系

pytorch小记(六):pytorch中的clone和detach操作:克隆/复制数据 vs 共享相同数据但 与计算图断开联系 1. x.clone()示例: 2. x.detach()示例:使用场景: 3. torch.tensor(x).float()示例&#xff…

MYSQL的第一次作业

目录 前情提要 题目解析 连接并使用数据库 创建employees表 创建orders表 创建invoices表 ​查看建立的表 前情提要 需要下载mysql并进行配置,建议下载8.0.37,详情可见MySQL超详细安装配置教程(亲测有效)_mysql安装教程-CSDN博客 题目解析 …