python3+TensorFlow 2.x(三)手写数字识别

server/2025/2/2 23:24:11/

目录

代码实现

模型解析:

1、加载 MNIST 数据集:

2、数据预处理:

3、构建神经网络模型:

4、编译模型:

5、训练模型:

6、评估模型:

7、预测和可视化结果:

输出结果:

总结:


代码实现

TensorFlow 2.x 实现手写数字识别(MNIST 数据集)。MNIST 数据集包含了 28x28 像素的手写数字图像,任务是将这些图像分类为 10 个类别(0-9) 

import tensorflow as tf
from tensorflow.keras import layers, models
import matplotlib.pyplot as plt# 1. 加载 MNIST 数据集
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()# 2. 数据预处理:归一化和改变形状
train_images = train_images / 255.0  # 将图像像素值归一化到 [0, 1]
test_images = test_images / 255.0# 调整形状,使得每张图片的维度是 [28, 28, 1],因为模型需要3D输入
train_images = train_images.reshape((train_images.shape[0], 28, 28, 1))
test_images = test_images.reshape((test_images.shape[0], 28, 28, 1))# 3. 构建神经网络模型
model = models.Sequential([layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),layers.MaxPooling2D((2, 2)),layers.Conv2D(64, (3, 3), activation='relu'),layers.MaxPooling2D((2, 2)),layers.Conv2D(64, (3, 3), activation='relu'),layers.Flatten(),layers.Dense(64, activation='relu'),layers.Dense(10, activation='softmax')  # 10类分类问题
])# 4. 编译模型:选择优化器、损失函数和评价指标
model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',  # 因为标签是整数,所以使用 sparse_categorical_crossentropymetrics=['accuracy'])# 5. 训练模型
history = model.fit(train_images, train_labels, epochs=5, validation_data=(test_images, test_labels))# 6. 评估模型
test_loss, test_acc = model.evaluate(test_images, test_labels)
print(f"Test accuracy: {test_acc}")# 7. 可视化训练过程中的损失和准确率变化
plt.plot(history.history['accuracy'], label='Training Accuracy')
plt.plot(history.history['val_accuracy'], label='Validation Accuracy')
plt.title('Training and Validation Accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()
plt.show()plt.plot(history.history['loss'], label='Training Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.title('Training and Validation Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()# 8. 使用模型进行预测
predictions = model.predict(test_images)# 显示一些预测结果
for i in range(5):plt.imshow(test_images[i].reshape(28, 28), cmap='gray')plt.title(f"Predicted Label: {predictions[i].argmax()}, Actual Label: {test_labels[i]}")plt.show()

模型解析:

1、加载 MNIST 数据集:

使用 tf.keras.datasets.mnist.load_data() 函数来加载 MNIST 数据集。返回的数据包括训练集和测试集。训练集有 60,000 张图像,测试集有 10,000 张图像。

2、数据预处理:

将图像的像素值从 [0, 255] 归一化到 [0, 1],使每个像素的值在 0 到 1 之间,提升模型的训练效果。将每张图像的形状调整为 (28, 28, 1),即每个图像是 28x28 的灰度图像。

3、构建神经网络模型:

使用卷积神经网络(CNN)构建模型:Conv2D 层用于提取图像的特征,使用了 ReLU 激活函数。MaxPooling2D 层用于下采样,减少计算量。Flatten 层将卷积层的输出展平,进入全连接层。Dense 层用于输出分类结果,其中最后一层使用了 softmax 激活函数,将模型的输出转换为 10 类的概率分布。

4、编译模型:

使用 adam 优化器,sparse_categorical_crossentropy 作为损失函数(适用于类别标签是整数的情况),并使用 accuracy 作为评价指标。

5、训练模型:

使用 model.fit 训练模型,设置了 5 个 epoch,使用训练集进行训练,并验证模型在测试集上的表现。

6、评估模型:

使用 model.evaluate 在测试集上评估模型的准确性。并可视化训练过程中的损失和准确率变化:使用 matplotlib 绘制训练过程中的损失和准确率变化曲线,查看模型的学习进度。

7、预测和可视化结果

使用训练好的模型对测试集进行预测,展示一些预测结果,并与真实标签进行对比。

输出结果

训练和验证准确率:随着训练的进行,准确率应该逐渐提高。
测试准确率:训练完成后,模型在测试集上的准确率会显示出来,通常可以达到 98% 以上。
预测图像:展示一些手写数字图像,标注预测的标签和实际标签。

预测可视化展示

总结:

该模型使用了卷积层、池化层以及全连接层,在 MNIST 数据集上训练,最终达到了很好的分类效果。你可以调整模型的超参数(例如卷积层的数量、神经元的数量等)以提高性能。


http://www.ppmy.cn/server/164458.html

相关文章

论文笔记(六十三)Understanding Diffusion Models: A Unified Perspective(五)

Understanding Diffusion Models: A Unified Perspective(五) 文章概括基于得分的生成模型(Score-based Generative Models) 文章概括 引用: article{luo2022understanding,title{Understanding diffusion models: A…

AI编程工具使用技巧:在Visual Studio Code中高效利用阿里云通义灵码

AI编程工具使用技巧:在Visual Studio Code中高效利用阿里云通义灵码 前言一、通义灵码介绍1.1 通义灵码简介1.2 主要功能1.3 版本选择1.4 支持环境 二、Visual Studio Code介绍1.1 VS Code简介1.2 主要特点 三、安装VsCode3.1下载VsCode3.2.安装VsCode3.3 打开VsCod…

OpenEuler学习笔记(十五):在OpenEuler上搭建Java运行环境

一、在OpenEuler上搭建Java运行环境 在OpenEuler上搭建Java运行环境可以通过以下几种常见方式,下面分别介绍基于包管理器安装OpenJDK和手动安装Oracle JDK的步骤。 使用包管理器安装OpenJDK OpenJDK是Java开发工具包的开源实现,在OpenEuler上可以方便…

《大数据时代“快刀”:Flink实时数据处理框架优势全解析》

在数字化浪潮中,数据呈爆发式增长,实时数据处理的重要性愈发凸显。从金融交易的实时风险监控,到电商平台的用户行为分析,各行业都急需能快速处理海量数据的工具。Flink作为一款开源的分布式流处理框架,在这一领域崭露头…

Vue 封装http 请求

封装message 提示 Message.js import { ElMessage } from "element-plus";const showMessage (msg,callback,type)>{ElMessage({message: msg,type: type,duration: 3000,onClose:()>{if (callback) {callback();}}}); }const message {error: (msg,…

c++面试:类定义为什么可以放到头文件中

这个问题是刚了解预编译的时候产生的疑惑。 声明是指向编译器告知某个变量、函数或类的存在及其类型,但并不分配实际的存储空间。声明的主要目的是让编译器知道如何解析程序中的符号引用。定义不仅告诉编译器实体的存在,还会为该实体分配存储空间&#…

XML DOM - 导航节点

可通过使用节点间的关系对节点进行导航。 导航 DOM 节点 通过节点间的关系访问节点树中的节点,通常称为导航节点("navigating nodes")。 在 XML DOM 中,节点的关系被定义为节点的属性: parentNodechildNo…

Unity实现按键设置功能代码

一、前言 最近在学习unity2D,想做一个横版过关游戏,需要按键设置功能,让用户可以自定义方向键与攻击键等。 自己写了一个,总结如下。 二、界面效果图 这个是一个csv文件,准备第一列是中文按键说明,第二列…