使用 Keras 训练一个卷积神经网络(CNN)(入门篇)

news/2024/11/19 18:31:24/

在上一篇文章中,我们介绍了如何使用 Keras 训练一个简单的全连接神经网络(MLP)。本文将带你深入学习如何使用 Keras 构建和训练一个卷积神经网络(CNN),用于图像分类任务。我们将继续使用 MNIST 数据集,但这次我们将采用更适合图像数据的 CNN 架构。

目录

  1. 环境准备
  2. 导入必要的库
  3. 加载和预处理数据
  4. 构建卷积神经网络模型
  5. 编译模型
  6. 训练模型
  7. 评估模型
  8. 保存和加载模型
  9. 可视化训练过程
  10. 总结

1. 环境准备

确保你已经安装了 Python(推荐 3.6 及以上版本)和 TensorFlow(Keras 已集成在 TensorFlow 中)。如果尚未安装,请运行以下命令:

pip install tensorflow

2. 导入必要的库

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, models
import numpy as np
import matplotlib.pyplot as plt
  • tensorflow: 深度学习框架,Keras 已集成其中。
  • numpy: 用于数值计算。
  • matplotlib.pyplot: 用于数据可视化。

3. 加载和预处理数据

我们继续使用 Keras 自带的 MNIST 数据集。

# 加载 MNIST 数据集
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()# 查看数据形状
print(f"训练数据形状: {x_train.shape}, 训练标签形状: {y_train.shape}")
print(f"测试数据形状: {x_test.shape}, 测试标签形状: {y_test.shape}")# 数据预处理
# 归一化:将像素值缩放到 0-1 之间
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0# CNN 需要添加通道维度
x_train = np.expand_dims(x_train, -1)  # 形状变为 (60000, 28, 28, 1)
x_test = np.expand_dims(x_test, -1)    # 形状变为 (10000, 28, 28, 1)# 将标签转换为分类编码
num_classes = 10
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)# 可视化部分数据
plt.figure(figsize=(10,10))
for i in range(25):plt.subplot(5,5,i+1)plt.xticks([])plt.yticks([])plt.grid(False)plt.imshow(x_train[i].reshape(28, 28), cmap=plt.cm.binary)plt.xlabel(np.argmax(y_train[i]))
plt.show()

说明:

  • CNN 需要输入数据具有通道维度,因此使用 np.expand_dims 添加一个维度。
  • MNIST 数据集是灰度图像,因此通道维度为 1。

4. 构建卷积神经网络模型

我们将构建一个简单的 CNN 模型,包含两个卷积层和两个池化层,最后接上全连接层进行分类。

model = models.Sequential([layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),  # 卷积层,32 个 3x3 卷积核layers.MaxPooling2D((2, 2)),  # 最大池化层,池化窗口 2x2layers.Conv2D(64, (3, 3), activation='relu'),  # 卷积层,64 个 3x3 卷积核layers.MaxPooling2D((2, 2)),  # 最大池化层layers.Flatten(),  # 展平层layers.Dense(64, activation='relu'),  # 全连接层,64 个神经元layers.Dense(num_classes, activation='softmax')  # 输出层,10 个神经元
])# 查看模型结构
model.summary()

说明:

  • Conv2D: 二维卷积层,用于提取图像特征。
  • MaxPooling2D: 最大池化层,用于下采样,减少参数数量。
  • Flatten: 将多维输入一维化,以便连接全连接层。
  • Dense: 全连接层,用于分类。

5. 编译模型

model.compile(optimizer='adam',loss='categorical_crossentropy',metrics=['accuracy'])

说明:

  • 使用 Adam 优化器和交叉熵损失函数。
  • 评估指标为准确率。

6. 训练模型

# 设置训练参数
batch_size = 128
epochs = 10# 训练模型
history = model.fit(x_train, y_train,batch_size=batch_size,epochs=epochs,validation_split=0.1)  # 使用 10% 的训练数据作为验证集

说明:

  • 使用 10% 的训练数据作为验证集,以监控模型在验证集上的性能。

7. 评估模型

test_loss, test_acc = model.evaluate(x_test, y_test, verbose=2)
print(f"\n测试准确率: {test_acc:.4f}")

8. 保存和加载模型

# 保存模型
model.save("mnist_cnn_model.h5")# 加载模型
new_model = keras.models.load_model("mnist_cnn_model.h5")

9. 可视化训练过程

# 绘制训练 & 验证的准确率和损失值
plt.figure(figsize=(12,4))# 准确率
plt.subplot(1,2,1)
plt.plot(history.history['accuracy'], label='训练准确率')
plt.plot(history.history['val_accuracy'], label='验证准确率')
plt.xlabel('Epoch')
plt.ylabel('准确率')
plt.legend(loc='lower right')
plt.title('训练与验证准确率')# 损失值
plt.subplot(1,2,2)
plt.plot(history.history['loss'], label='训练损失')
plt.plot(history.history['val_loss'], label='验证损失')
plt.xlabel('Epoch')
plt.ylabel('损失')
plt.legend(loc='upper right')
plt.title('训练与验证损失')plt.show()

说明:

  • 通过可视化训练过程中的准确率和损失值,可以帮助我们了解模型的训练情况,判断是否存在过拟合或欠拟合。

10. 本节回顾

本节介绍了如何使用 Keras 构建和训练一个简单的卷积神经网络(CNN),用于手写数字识别任务。主要步骤包括:

  1. 环境准备和库导入: 确保安装了必要的库,并导入所需模块。
  2. 数据加载和预处理: 加载 MNIST 数据集,进行归一化,并添加通道维度。
  3. 构建 CNN 模型: 使用 Conv2D、MaxPooling2D、Flatten、Dense 等层构建模型。
  4. 编译模型: 指定优化器、损失函数和评估指标。
  5. 训练模型: 使用训练数据训练模型,并使用验证集监控性能。
  6. 评估模型: 在测试集上评估模型性能。
  7. 保存和加载模型: 将训练好的模型保存到磁盘,并可加载进行预测。
  8. 可视化训练过程: 通过绘制准确率和损失值曲线,了解模型的训练情况。

通过这个基础教程,你可以开始自行探索更复杂的 CNN 模型和更深入的应用,如图像分类、目标检测、图像分割等。

导师简介

前腾讯电子签的前端负责人,现 whentimes tech CTO,专注于前端技术的大咖一枚!一路走来,从小屏到大屏,从 Web 到移动,什么前端难题都见过。热衷于用技术打磨产品,带领团队把复杂的事情做到极简,体验做到极致。喜欢探索新技术,也爱分享一些实战经验,帮助大家少走弯路!

温馨提示:可搜老码小张公号联系导师


http://www.ppmy.cn/news/1548292.html

相关文章

linux alsa-lib snd_pcm_open函数源码分析(四)

欢迎直接到博客 linux alsa-lib snd_pcm_open函数源码分析(四) 系列文章其他部分: linux alsa-lib snd_pcm_open函数源码分析(一) linux alsa-lib snd_pcm_open函数源码分析(二) linux alsa-lib snd_pcm_open函数源码分析(三)…

CSS响应式布局实现1920屏幕1rem等于100px

代码解析与实现 设置根元素的 font-size 为 5.208333vw 假设你想让根元素的 font-size 基于视口宽度来动态调整。我们可以通过设置 font-size 为 5.208333vw 来让 1rem 相当于视口宽度的 5.208333%。 计算 5.208333vw: 当屏幕宽度为 1920px 时,5.208333vw 等于 5…

后端返回大数问题

这个问题并不难,但是在开发的时候没有注意到 后端返回了一个列表数据,包含id,这个id是一个大数,列表进入详情,需要将id传入到详情页面详情页面内部通过id获取数据一直404,id不正确找问题,从路由传参到请求数据发现id没有问题,然后和后端进行联调,发现后端返回的id和我获取的id…

django从入门到精通(五)——表单与模型

好的,下面将详细介绍 Django 的表单与模型,包括它们的定义、使用、如何在 Django Admin 中结合使用,以及相关的字段类型和验证机制。 Django 模型与表单 1. Django 模型 Django 模型是一个 Python 类,用于定义数据库中的数据结…

【ACM出版】第四届信号处理与通信技术国际学术会议(SPCT 2024)

& 第四届信号处理与通信技术国际学术会议(SPCT 2024) 2024 4th International Conference on Signal Processing and Communication Technology 2024年12月27-29日 中国深圳 www.icspct.com 第四届信号处理与通信技术国际学术会议&#x…

Python 正则表达式进阶用法:字符集与字符范围详解

Python 正则表达式进阶用法:字符集与字符范围详解 正则表达式是文本处理和数据清洗中不可或缺的工具。在前面的学习中,我们已经了解了基本的正则表达式匹配,如匹配单个字符、字符串开始和结束的位置等。今天,我们将进入正则表达式…

定时清理潜在客户列表中的无效邮箱可提高EDM电子邮件自动化营销邮件送达率

定时清理无效邮箱对于邮件营销来说,具有多重好处,这些好处直接关系到营销活动的效率、成本节约、品牌形象以及法律合规性。以下是几个关键方面: 提高邮件送达率: 无效邮箱(如不存在、拼写错误或已废弃的邮箱地址&…

Filebeat升级秘籍:解锁日志收集新境界

文章目录 一、什么是filebeat二、Filebeat的工作原理2.1 filebeat的构成2.1.1 Prospector 组件2.1.2 Harvester 组件 2.2 filebeat如何保存文件的状态2.3 filebeat何如保证至少一次数据消费 三、Filebeat配置文件四、filebeat对比fluented五、Filebeat的部署安装5.1裸金属安装5…