使用TensorFlow实现简化版 GoogLeNet 模型进行 MNIST 图像分类

embedded/2024/11/23 9:10:28/

        在本文中,我们将使用 TensorFlow 和 Keras 实现一个简化版的 GoogLeNet 模型来进行 MNIST 数据集的手写数字分类任务。GoogLeNet 采用了 Inception 模块,这使得它在处理图像数据时能更高效地提取特征。本教程将详细介绍如何在 MNIST 数据集上训练和测试这个模型。

项目结构

        我们的代码将分为两个部分:

  1. 训练部分 (train.py): 包含模型定义、数据加载、模型训练等。
  2. 测试部分 (test.py): 用于加载训练好的模型,并在测试集上评估其性能。

训练部分:train.py

1. 数据加载与预处理

        首先,我们需要加载 MNIST 数据集并进行预处理。预处理包括调整图像形状、归一化以及 One-Hot 编码标签。

import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categoricaldef load_and_preprocess_data():# 加载 MNIST 数据集(train_images, train_labels), (test_images, test_labels) = mnist.load_data()# 数据预处理:将图像形状调整为 [28, 28, 1],并归一化到 [0, 1] 范围train_images = train_images.reshape((train_images.shape[0], 28, 28, 1)) / 255.0test_images = test_images.reshape((test_images.shape[0], 28, 28, 1)) / 255.0# One-Hot 编码标签train_labels = to_categorical(train_labels, 10)test_labels = to_categorical(test_labels, 10)return train_images, train_labels, test_images, test_labels

2. 创建简化版 GoogLeNet 模型

        接下来,我们定义一个简化版的 GoogLeNet 模型。该模型包括卷积层、Inception 模块和全连接层。

from tensorflow.keras import layers, modelsdef googlenet(input_shape=(28, 28, 1), num_classes=10):inputs = layers.Input(shape=input_shape)# 第一卷积层 + 池化层x = layers.Conv2D(32, (3, 3), padding='same', activation='relu')(inputs)x = layers.MaxPooling2D((2, 2))(x)# 第二卷积层 + 池化层x = layers.Conv2D(64, (3, 3), padding='same', activation='relu')(x)x = layers.MaxPooling2D((2, 2))(x)# 第三卷积层 + 池化层x = layers.Conv2D(128, (3, 3), padding='same', activation='relu')(x)x = layers.MaxPooling2D((2, 2))(x)# Inception 模块inception1 = layers.Conv2D(64, (1, 1), activation='relu', padding='same')(x)inception2 = layers.Conv2D(128, (3, 3), activation='relu', padding='same')(x)inception3 = layers.Conv2D(32, (5, 5), activation='relu', padding='same')(x)# 拼接 Inception 模块的输出x = layers.concatenate([inception1, inception2, inception3], axis=-1)# 全局平均池化层x = layers.GlobalAveragePooling2D()(x)# 全连接层x = layers.Dense(1024, activation='relu')(x)x = layers.Dropout(0.5)(x)  # Dropout 层减少过拟合outputs = layers.Dense(num_classes, activation='softmax')(x)  # 输出层,使用 softmax 激活函数进行多分类model = models.Model(inputs=inputs, outputs=outputs)return model

3. 模型训练

        定义好模型之后,我们使用 Adam 优化器和交叉熵损失函数来训练模型,并保存训练好的模型。

def train_model(model, train_images, train_labels, epochs=5, batch_size=64):# 训练模型history = model.fit(train_images, train_labels,epochs=epochs,batch_size=batch_size)return historydef save_model(model, filename='googlenet_mnist.h5'):model.save(filename)print(f"Model saved to {filename}")

4. 主程序

        最后,在主程序中,我们加载数据、创建模型并开始训练。

def main():train_images, train_labels, test_images, test_labels = load_and_preprocess_data()model = googlenet(input_shape=(28, 28, 1), num_classes=10)model.compile(optimizer='adam',loss='categorical_crossentropy',metrics=['accuracy'])train_model(model, train_images, train_labels, epochs=5, batch_size=64)save_model(model)if __name__ == '__main__':main()


测试部分:test.py

1. 加载训练好的模型

        在测试部分,我们将加载训练好的模型,并在测试集上进行评估。

import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categoricaldef load_and_preprocess_data():(_, _), (test_images, test_labels) = mnist.load_data()test_images = test_images.reshape((test_images.shape[0], 28, 28, 1)) / 255.0test_labels = to_categorical(test_labels, 10)return test_images, test_labelsdef load_model(model_path='googlenet_mnist.h5'):model = tf.keras.models.load_model(model_path)return model

2. 评估模型

        我们通过 evaluate 方法评估模型的损失和准确率。

def evaluate_model(model, test_images, test_labels):test_loss, test_acc = model.evaluate(test_images, test_labels)print(f"Test accuracy: {test_acc * 100:.2f}%")return test_loss, test_acc

3. 显示预测结果

        使用 Matplotlib 可视化前几张图片的预测结果。

import matplotlib.pyplot as pltdef display_predictions(model, test_images, test_labels, num_images=6):predictions = model.predict(test_images[:num_images])fig, axes = plt.subplots(2, 3, figsize=(10, 6))axes = axes.flatten()for i in range(num_images):ax = axes[i]ax.imshow(test_images[i].reshape(28, 28), cmap='gray')ax.set_title(f"Pred: {tf.argmax(predictions[i]).numpy()} \n True: {tf.argmax(test_labels[i]).numpy()}")ax.axis('off')plt.tight_layout()plt.show()

4. 主程序

        在主程序中,我们加载模型,评估其性能,并显示预测结果。

def main():test_images, test_labels = load_and_preprocess_data()model = load_model('googlenet_mnist.h5')evaluate_model(model, test_images, test_labels)display_predictions(model, test_images, test_labels)if __name__ == '__main__':main()


总结

        本文介绍了如何使用 TensorFlow 实现简化版 GoogLeNet,并在 MNIST 数据集上进行训练和测试。我们将代码分为训练和测试两部分,分别处理数据预处理、模型训练与评估、结果展示等工作。

        通过使用 GoogLeNet 进行图像分类,我们不仅能够提高分类性能,还能了解 Inception 模块在图像处理中的强大能力。希望这篇博客能够帮助你更好地理解深度学习模型的训练与测试过程。

完整项目:GoogLeNet-TensorFlow: 使用TensorFlow实现简化版 GoogLeNet 进行 MNIST 图像分类icon-default.png?t=O83Ahttps://gitee.com/qxdlll/goog-le-net-tensor-flow

qxd-ljy/GoogLeNet-TensorFlow: 使用 TensorFlow实现简化版 GoogLeNet 进行 MNIST 图像分类icon-default.png?t=O83Ahttps://github.com/qxd-ljy/GoogLeNet-TensorFlow 


http://www.ppmy.cn/embedded/139815.html

相关文章

redis工程实战介绍(含面试题)

文章目录 redis单线程VS多线程面试题**redis是多线程还是单线程,为什么是单线程****聊聊redis的多线程特性和IO多路复用****io多路复用模型****redis如此快的原因** BigKey大批量插入数据测试数据key面试题海量数据里查询某一固定前缀的key如果生产上限值keys * ,fl…

egrep grep 区别

‌egrep 和 grep 的主要区别在于对正则表达式的支持。 -rwxr-xr-x 1 root root 28 Jan 29 2020 /bin/egrep -rwxr-xr-x 1 root root 199136 Jan 29 2020 /bin/grep 1e6ebb9dd094f774478f72727bdba0f5 /bin/grep ef55d1537377114cc24cdc398fbdd930 /bin/egrep 区别 gre…

js中new操作符具体都干了什么?

在JavaScript中,new操作符是一个用于创建对象实例的关键字,它背后的机制相当复杂,但以下是它执行的主要步骤: new操作符的工作原理: 创建一个全新的空对象:首先,JavaScript会创建一个全新的对象…

vue3 + elementPlus 日期时间选择器禁用未来及过去时间

<el-date-pickerv-model"form.jyTime"type"datetime"placeholder"请选择加油时间"format"YYYY/MM/DD HH:mm:ss"value-format"YYYY-MM-DD HH:mm:ss":disabled-date"disabledDate"/> 一、禁用未来时间 /** 时…

基础自动化系统的任务

基础自动化系统的任务主要包括实现自动控制、提高生产效率、减少人工干预等。以下是其具体任务的相关介绍&#xff1a; 实现自动控制 控制机器设备&#xff1a;基础自动化系统通过预设的程序和逻辑规则&#xff0c;对机器或设备进行自动控制和运行。执行特定任务&#xff1a;这…

java基础(一):JDK、JRE、JVM、类库等概念,java跨平台实现原理

目录 1、基本概念 2、程序运行过程 3、java跨平台原理 1、基本概念 JVM&#xff1a;虚拟机&#xff0c;真正运行java程序的地方 核心类库&#xff1a;java自己写好的程序&#xff0c;给程序员自己调用的&#xff0c;例如System.out.println()&#xff0c;调用的就是 核心…

leetcode105:从前序与中序遍历构建二叉树

给定两个整数数组 preorder 和 inorder &#xff0c;其中 preorder 是二叉树的先序遍历&#xff0c; inorder 是同一棵树的中序遍历&#xff0c;请构造二叉树并返回其根节点。 示例 1: 输入: preorder [3,9,20,15,7], inorder [9,3,15,20,7] 输出: [3,9,20,null,null,15,7]示…

『Linux』 第四章 进程—— 进程状态讲解

目录 1.1.1 通过系统调用创建进程-fork初识 1.2 进程状态 1.2.1 Linux内核源代码怎么说 1.2.2 进程状态查看 1.2.3 Z(zombie)-僵尸进程 1.2.4 僵尸进程危害 1.2.5 孤儿进程 1.3 进程优先级 1.3.1 基本概念 1.3.2 查看系统进程 1.3.3 PRI and NI 1.3.4 PRI vs NI …