卷积神经网络(CNN)入门:使用Python实现手写数字识别

news/2024/11/28 15:44:40/

在上一篇文章中,我们介绍了如何使用Python实现一个简单的前馈神经网络。本文将重点介绍卷积神经网络(CNN),这是一种在计算机视觉任务中表现优异的深度学习模型。我们将从卷积神经网络的基本原理开始,介绍卷积层、池化层和全连接层等概念,然后使用Python和Keras库实现一个手写数字识别的例子。

1.卷积神经网络(CNN)简介

卷积神经网络(Convolutional Neural Networks,简称CNN)是一种特殊的前馈神经网络,主要用于处理具有类似网格结构的数据,如图像。与传统神经网络相比,卷积神经网络在图像识别、语音识别等领域具有更好的性能。CNN的主要优势在于其能够自动学习局部特征并组合成全局特征,有效减少模型参数,降低过拟合的风险。

2.CNN的基本组成

卷积神经网络主要由三种类型的层组成:卷积层、池化层和全连接层。

2.1. 卷积层

卷积层是CNN的核心部分。卷积层通过在输入数据上滑动一个卷积核,从而捕捉局部特征。卷积核的大小和数量是网络的超参数,可以根据具体任务进行调整。

2.2. 池化层

池化层用于降低特征图的空间尺寸,从而减少计算量和参数。最常见的池化操作是最大池化(Max Pooling)和平均池化(Average Pooling)。

2.3. 全连接层

全连接层用于将特征图展平并输出最终的分类结果。通常,全连接层位于卷积神经网络的末端。

3.使用Python和Keras实现手写数字识别

在本节中,我们将使用Python和Keras库实现一个简单的卷积神经网络,用于识别手写数字。首先,我们需要安装Keras库,并引入所需的库和模块。

!pip install kerasimport numpy as np
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout
from keras.utils import to_categorical

接下来,我们加载MNIST数据集,并对数据进行预处理。

# 加载数据
(x_train, y_train), (x_test, y_test) = mnist.load_data()# 归一化
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0# 扩展数据维度
x_train = np.expand_dims(x_train, axis=-1)
x_test = np.expand_dims(x_test, axis=-1)# 对标签进行One-hot编码
y_train = to_categorical(y_train, num_classes=10)
y_test = to_categorical(y_test, num_classes=10)

现在,我们构建卷积神经网络模型。

model = Sequential()# 添加第一个卷积层和池化层
model.add(Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=(28, 28, 1)))
model.add(MaxPooling2D(pool_size=(2, 2)))# 添加第二个卷积层和池化层
model.add(Conv2D(64, kernel_size=(3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))# 添加展平层
model.add(Flatten())# 添加全连接层
model.add(Dense(128, activation='relu'))# 添加Dropout层
model.add(Dropout(0.5))# 添加输出层
model.add(Dense(10, activation='softmax'))

接下来,我们编译模型并设置优化器、损失函数和评价指标。

model.compile(optimizer=Adam(learning_rate=0.001),loss='categorical_crossentropy',metrics=['accuracy'])

现在,我们训练模型。

model.fit(x_train, y_train,batch_size=128,epochs=10,verbose=1,validation_data=(x_test, y_test))

最后,我们评估模型在测试集上的性能。

score = model.evaluate(x_test, y_test, verbose=0)
print('Test loss:', score[0])
print('Test accuracy:', score[1])

本文向您详细介绍了卷积神经网络(CNN)的基本原理和组成部分,并使用Python和Keras库实现了一个简单的手写数字识别模型。通过这个案例,您可以更好地理解卷积神经网络的原理和实现过程。在后续的文章中,我们将继续深入探讨神经网络的其他类型和技术,帮助您更好地应用神经网络解决实

在本文中,我们将进一步深入介绍卷积神经网络的原理,并详细展示如何使用Python和Keras构建一个手写数字识别模型。

首先,我们简要回顾一下卷积神经网络的基本组成部分。

  1. 卷积层:通过在输入数据上滑动一个卷积核,从而捕捉局部特征。
  2. 激活函数:通常在卷积层之后添加激活函数,以引入非线性。常见的激活函数有ReLU、Sigmoid和Tanh等。
  3. 池化层:用于降低特征图的空间尺寸,从而减少计算量和参数。最常见的池化操作是最大池化和平均池化。
  4. 全连接层:将特征图展平并输出最终的分类结果。通常,全连接层位于卷积神经网络的末端。

现在,我们使用Python和Keras库构建一个手写数字识别的卷积神经网络。

数据预处理

在加载MNIST数据集之后,我们需要对其进行预处理。这包括归一化、调整数据维度以适应卷积神经网络的输入要求,以及对标签进行One-hot编码。

# 归一化
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0# 扩展数据维度
x_train = np.expand_dims(x_train, axis=-1)
x_test = np.expand_dims(x_test, axis=-1)# 对标签进行One-hot编码
y_train = to_categorical(y_train, num_classes=10)
y_test = to_categorical(y_test, num_classes=10)

构建模型

接下来,我们构建卷积神经网络模型。在这个示例中,我们将使用两个卷积层、两个池化层和一个全连接层。

model = Sequential()# 添加第一个卷积层和激活函数
model.add(Conv2D(32, kernel_size=(3, 3), input_shape=(28, 28, 1)))
model.add(Activation('relu'))# 添加第一个池化层
model.add(MaxPooling2D(pool_size=(2, 2)))# 添加第二个卷积层和激活函数
model.add(Conv2D(64, kernel_size=(3, 3)))
model.add(Activation('relu'))# 添加第二个池化层
model.add(MaxPooling2D(pool_size=(2, 2)))# 添加展平层
model.add(Flatten())# 添加全连接层和激活函数
model.add(Dense(128))
model.add(Activation('relu'))# 添加Dropout层
model.add(Dropout(0.5))# 添加输出层和激活函数
model.add(Dense(10))
model.add(Activation('softmax'))
``

编译模型

在模型构建完成后,我们需要编译模型并设置优化器、损失函数和评价指标。

model.compile(optimizer=Adam(learning_rate=0.001),loss='categorical_crossentropy',metrics=['accuracy'])

训练模型

现在,我们可以开始训练模型。我们将使用批量梯度下降法训练模型,设置批量大小为128,训练10个周期。

 
history = model.fit(x_train, y_train,batch_size=128,epochs=10,verbose=1,validation_data=(x_test, y_test))

在训练过程中,我们可以通过history对象跟踪训练和验证的损失和准确率。这可以帮助我们诊断模型是否过拟合或欠拟合。

可视化训练过程

为了更直观地了解模型训练过程中的性能变化,我们可以将训练和验证的损失和准确率绘制在图表上。

import matplotlib.pyplot as plt# 绘制损失曲线
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('Model Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend(['Train', 'Test'], loc='upper right')
plt.show()# 绘制准确率曲线
plt.plot(history.history['accuracy'])
plt.plot(history.history['val_accuracy'])
plt.title('Model Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend(['Train', 'Test'], loc='upper left')
plt.show()

评估模型性能

训练完成后,我们可以评估模型在测试集上的性能。

score = model.evaluate(x_test, y_test, verbose=0)
print('Test loss:', score[0])
print('Test accuracy:', score[1])

在本文中,我们详细介绍了卷积神经网络的原理,并使用Python和Keras构建了一个手写数字识别模型。我们还展示了如何在训练过程中跟踪模型性能并可视化训练过程。这些知识将帮助您更好地理解卷积神经网络的原理和实现过程,为您在深度学习领域的进一步探索奠定基础。在后续的文章中,我们将继续深入探讨神经网络的其他类型和技术,帮助您更好地应用神经网络解决实际问题。


http://www.ppmy.cn/news/41770.html

相关文章

读什么书能让你进入高层次

“黄金非宝书为宝,万事皆空善不空。”我很少看到有人说读书不好的,但却很少看到有人读好书。好书、好读书、读好书,都是很稀缺的。好书的作用基本上,我们遇到的每个困惑,都有一本书能够给出解答。因为你的困惑并不独特…

Android11.0 系统Framework发送通知流程分析

1.前言 在android 11.0的系统rom定制化开发中,在systemui中一个重要的内容就是系统通知的展示,在状态栏展示系统发送通知的图标,而在 系统下拉通知栏中展示接收到的系统发送过来的通知,所以说对系统framework中发送通知的流程分析很重要,接下来就来分析下系统 通知从fram…

Redis源码之SDS简单动态字符串

Redis 是内存数据库,高效使用内存对 Redis 的实现来说非常重要。 看一下,Redis 中针对字符串结构针对内存使用效率做的设计优化。 一、SDS的结构 c语言没有string类型,本质是char[]数组;而且c语言数组创建时必须初始化大小&#…

19学习提升:gRPC源码中的那些优秀设计(上)

gRPC作为高性能的RPC框架,离不开它优雅的设计和编码,无论是作为一名底层开发者还是上层的业务开发者,能够写出一手好的代码一直都是决定自身水平高低的一个重要体现,如果想要达到一个较高层次的水平,离不开长时间的学习和训练以及不断的感悟,而一些优秀的开源软件和框架往…

Linux使用:环境变量指南和CPU和GPU利用情况查看

Linux使用:环境变量指南和CPU和GPU利用情况查看Linux环境变量初始化与对应文件的生效顺序Linux的变量种类设置环境变量直接运行export命令定义变量修改系统环境变量修改用户环境变量修改环境变量配置文件环境配置文件的区别profile、 bashrc、.bash_profile、 .bash…

maven使用教程

文章目录IDEA创建maven项目maven项目必有得目录结构项目构建关键字cleanvalidatecompiletestpackageverifyinstallsitedeploy命令使用方法方法一 在terminal终端执行方法二 在右侧得maven中双击依赖管理在pom.xml下 导包、scope的传递范围、打包方式依赖冲突声明优先原则就近原…

算法学习|动态规划 LeetCode 300.最长递增子序列、674. 最长连续递增序列、718. 最长重复子数组

动态规划一、最长递增子序列思路实现代码二、最长**连续**递增序列思路实现代码三、最长重复子数组思路实现代码一、最长递增子序列 给你一个整数数组 nums ,找到其中最长严格递增子序列的长度。 子序列是由数组派生而来的序列,删除(或不删除…

超市购物系统【GUI/Swing+MySQL】(Java课设)

系统类型 Swing窗口类型Mysql数据库存储数据 使用范围 适合作为Java课设!!! 部署环境 jdk1.8Mysql8.0Idea或eclipsejdbc 运行效果 本系统源码地址:https://download.csdn.net/download/qq_50954361/87682510 更多系统资源库…