python3+TensorFlow 2.x(三)手写数字识别

ops/2025/2/2 23:43:14/

目录

代码实现

模型解析:

1、加载 MNIST 数据集:

2、数据预处理:

3、构建神经网络模型:

4、编译模型:

5、训练模型:

6、评估模型:

7、预测和可视化结果:

输出结果:

总结:


代码实现

TensorFlow 2.x 实现手写数字识别(MNIST 数据集)。MNIST 数据集包含了 28x28 像素的手写数字图像,任务是将这些图像分类为 10 个类别(0-9) 

import tensorflow as tf
from tensorflow.keras import layers, models
import matplotlib.pyplot as plt# 1. 加载 MNIST 数据集
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()# 2. 数据预处理:归一化和改变形状
train_images = train_images / 255.0  # 将图像像素值归一化到 [0, 1]
test_images = test_images / 255.0# 调整形状,使得每张图片的维度是 [28, 28, 1],因为模型需要3D输入
train_images = train_images.reshape((train_images.shape[0], 28, 28, 1))
test_images = test_images.reshape((test_images.shape[0], 28, 28, 1))# 3. 构建神经网络模型
model = models.Sequential([layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),layers.MaxPooling2D((2, 2)),layers.Conv2D(64, (3, 3), activation='relu'),layers.MaxPooling2D((2, 2)),layers.Conv2D(64, (3, 3), activation='relu'),layers.Flatten(),layers.Dense(64, activation='relu'),layers.Dense(10, activation='softmax')  # 10类分类问题
])# 4. 编译模型:选择优化器、损失函数和评价指标
model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',  # 因为标签是整数,所以使用 sparse_categorical_crossentropymetrics=['accuracy'])# 5. 训练模型
history = model.fit(train_images, train_labels, epochs=5, validation_data=(test_images, test_labels))# 6. 评估模型
test_loss, test_acc = model.evaluate(test_images, test_labels)
print(f"Test accuracy: {test_acc}")# 7. 可视化训练过程中的损失和准确率变化
plt.plot(history.history['accuracy'], label='Training Accuracy')
plt.plot(history.history['val_accuracy'], label='Validation Accuracy')
plt.title('Training and Validation Accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()
plt.show()plt.plot(history.history['loss'], label='Training Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.title('Training and Validation Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()# 8. 使用模型进行预测
predictions = model.predict(test_images)# 显示一些预测结果
for i in range(5):plt.imshow(test_images[i].reshape(28, 28), cmap='gray')plt.title(f"Predicted Label: {predictions[i].argmax()}, Actual Label: {test_labels[i]}")plt.show()

模型解析:

1、加载 MNIST 数据集:

使用 tf.keras.datasets.mnist.load_data() 函数来加载 MNIST 数据集。返回的数据包括训练集和测试集。训练集有 60,000 张图像,测试集有 10,000 张图像。

2、数据预处理:

将图像的像素值从 [0, 255] 归一化到 [0, 1],使每个像素的值在 0 到 1 之间,提升模型的训练效果。将每张图像的形状调整为 (28, 28, 1),即每个图像是 28x28 的灰度图像。

3、构建神经网络模型:

使用卷积神经网络(CNN)构建模型:Conv2D 层用于提取图像的特征,使用了 ReLU 激活函数。MaxPooling2D 层用于下采样,减少计算量。Flatten 层将卷积层的输出展平,进入全连接层。Dense 层用于输出分类结果,其中最后一层使用了 softmax 激活函数,将模型的输出转换为 10 类的概率分布。

4、编译模型:

使用 adam 优化器,sparse_categorical_crossentropy 作为损失函数(适用于类别标签是整数的情况),并使用 accuracy 作为评价指标。

5、训练模型:

使用 model.fit 训练模型,设置了 5 个 epoch,使用训练集进行训练,并验证模型在测试集上的表现。

6、评估模型:

使用 model.evaluate 在测试集上评估模型的准确性。并可视化训练过程中的损失和准确率变化:使用 matplotlib 绘制训练过程中的损失和准确率变化曲线,查看模型的学习进度。

7、预测和可视化结果

使用训练好的模型对测试集进行预测,展示一些预测结果,并与真实标签进行对比。

输出结果

训练和验证准确率:随着训练的进行,准确率应该逐渐提高。
测试准确率:训练完成后,模型在测试集上的准确率会显示出来,通常可以达到 98% 以上。
预测图像:展示一些手写数字图像,标注预测的标签和实际标签。

预测可视化展示

总结:

该模型使用了卷积层、池化层以及全连接层,在 MNIST 数据集上训练,最终达到了很好的分类效果。你可以调整模型的超参数(例如卷积层的数量、神经元的数量等)以提高性能。


http://www.ppmy.cn/ops/155155.html

相关文章

数据挖掘常用算法

文章目录 基于机器学习~~线性/逻辑回归~~树模型~~贝叶斯~~~~聚类~~集成算法神经网络~~支持向量机~~~~降维算法~~ 基于机器学习 线性/逻辑回归 类似单层神经网络 yk*xb 树模型 优点 可以做可视化分析速度快结果稳定 依赖前期对业务和数据的理解 贝叶斯 贝叶斯依赖先验概…

【10】如何辨别IOS AP镜像

1.概述 本文将针对思科的IOS AP来判断AP的镜像,通常我们通过直接的AP名称,很难判断该AP具体的软件版本,包括这个AP镜像是给什么型号的AP使用的,本文将针对这些内容进行介绍。 2.AP镜像了解 在思科官方下载瘦AP的镜像,一般都是15.3...,这个需要下载完毕,解压,可以看到…

解决运行npm时报错

在运行一个Vue项目时报错,产生下面问题 D:\node\npm.cmd run dev npm WARN logfile could not be created: Error: EPERM: operation not permitted, open D:\node\node_cache\_logs\2025-01-31T01_01_58_076Z-debug-0.log npm WARN logfile could not be created:…

第一个Python程序

目录 1.命令行模式 2.Python交互模式 3.命令行模式和Python交互模式 4.SyntaxError 5.小结 2.使用文本编辑器 1.Visual Studio Code! 2.直接运行py文件 3.输入和输出 1.输出 2.输入 3.小结 在正式编写第一个Python程序前,我们先复习一下什么是命令行模式…

调音基础学习

1、降噪 本质是噪声门,原理就是 1、设置阈值 2、低于阈值的电平全部滤掉 3、高于阈值的电平全部通过(包括环境音) 所以,阈值设置在大于环境音高一点点 降噪不能用于唱歌的录音,会损坏声音动态,所以用…

基于Django的个人博客系统的设计与实现

【Django】基于Django的个人博客系统的设计与实现(完整系统源码开发笔记详细部署教程)✅ 目录 一、项目简介二、项目界面展示三、项目视频展示 一、项目简介 系统采用Python作为主要开发语言,结合Django框架构建后端逻辑,并运用J…

AI(计算机视觉)自学路线

本文仅用来记录一下自学路线方便日后复习,如果对你自学有帮助的话也很开心o(* ̄▽ ̄*)ブ B站吴恩达机器学习->B站小土堆pytorch基础学习->opencv相关知识(Halcon或者opencv库)->四类神经网络(这里跟…

C_C++输入输出(下)

C_C输入输出&#xff08;下&#xff09; 用两次循环的问题&#xff1a; 1.一次循环决定打印几行&#xff0c;一次循环决定打印几项 cin是>> cout是<< 字典序是根据字符在字母表中的顺序来比较和排列字符串的&#xff08;字典序的大小就是字符串的大小&#xff09;…