深度学习实战:使用卷积神经网络(CNN)进行图像分类

news/2025/1/23 23:08:57/

在当今的机器学习领域,深度学习,尤其是卷积神经网络(CNN),已经在图像分类、物体检测、自然语言处理等领域取得了巨大的成功。本文将通过一个实际的例子,展示如何使用TensorFlow和Keras库构建一个卷积神经网络来进行图像分类。我们将使用经典的CIFAR-10数据集,该数据集包含60000张32x32的彩色图像,分为10个类别。

环境准备

首先,确保你已经安装了TensorFlow。你可以使用以下命令安装:


pip install tensorflow

数据集加载

CIFAR-10数据集是Keras库自带的数据集之一,我们可以直接加载:


import tensorflow as tffrom tensorflow.keras.datasets import cifar10from tensorflow.keras.utils import to_categorical# 加载数据集(x_train, y_train), (x_test, y_test) = cifar10.load_data()# 数据归一化到[0, 1]范围x_train, x_test = x_train / 255.0, x_test / 255.0# 将标签转换为one-hot编码y_train = to_categorical(y_train, 10)y_test = to_categorical(y_test, 10)

构建CNN模型

接下来,我们定义一个简单的卷积神经网络模型:


from tensorflow.keras.models import Sequentialfrom tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropoutmodel = Sequential([Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)),MaxPooling2D((2, 2)),Conv2D(64, (3, 3), activation='relu'),MaxPooling2D((2, 2)),Conv2D(64, (3, 3), activation='relu'),Flatten(),Dense(64, activation='relu'),Dropout(0.5),Dense(10, activation='softmax')])

编译和训练模型

在训练模型之前,我们需要编译模型,指定损失函数、优化器和评估指标:


model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])# 训练模型history = model.fit(x_train, y_train, epochs=20, batch_size=64, validation_data=(x_test, y_test))

评估模型

训练完成后,我们可以在测试集上评估模型的性能:


test_loss, test_acc = model.evaluate(x_test, y_test, verbose=2)print(f'Test accuracy: {test_acc}')

可视化训练过程

为了更好地理解模型的训练过程,我们可以可视化损失和准确率的变化:


import matplotlib.pyplot as plt# 绘制训练和验证的准确率变化plt.plot(history.history['accuracy'], label='accuracy')plt.plot(history.history['val_accuracy'], label = 'val_accuracy')plt.xlabel('Epoch')plt.ylabel('Accuracy')plt.ylim([0, 1])plt.legend(loc='lower right')plt.show()# 绘制训练和验证的损失变化plt.plot(history.history['loss'], label='loss')plt.plot(history.history['val_loss'], label = 'val_loss')plt.xlabel('Epoch')plt.ylabel('Loss')plt.legend(loc='upper right')plt.show()

总结

通过以上步骤,我们成功构建了一个简单的卷积神经网络,并在CIFAR-10数据集上进行了训练和评估。这个模型虽然简单,但已经能够在测试集上达到不错的准确率。你可以尝试调整模型的架构、增加更多的层、使用不同的优化器或正则化技术,以进一步提高模型的性能。

完整的代码如下:


import tensorflow as tffrom tensorflow.keras.datasets import cifar10from tensorflow.keras.utils import to_categoricalfrom tensorflow.keras.models import Sequentialfrom tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropoutimport matplotlib.pyplot as plt# 加载数据集(x_train, y_train), (x_test, y_test) = cifar10.load_data()# 数据归一化到[0, 1]范围x_train, x_test = x_train / 255.0, x_test / 255.0# 将标签转换为one-hot编码y_train = to_categorical(y_train, 10)y_test = to_categorical(y_test, 10)# 构建模型model = Sequential([Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)),MaxPooling2D((2, 2)),Conv2D(64, (3, 3), activation='relu'),MaxPooling2D((2, 2)),Conv2D(64, (3, 3), activation='relu'),Flatten(),Dense(64, activation='relu'),Dropout(0.5),Dense(10, activation='softmax')])# 编译模型model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])# 训练模型history = model.fit(x_train, y_train, epochs=20, batch_size=64, validation_data=(x_test, y_test))# 评估模型test_loss, test_acc = model.evaluate(x_test, y_test, verbose=2)print(f'Test accuracy: {test_acc}')# 可视化训练和验证的准确率变化plt.plot(history.history['accuracy'], label='accuracy')plt.plot(history.history['val_accuracy'], label = 'val_accuracy')plt.xlabel('Epoch')plt.ylabel('Accuracy')plt.ylim([0, 1])plt.legend(loc='lower right')plt.show()# 可视化训练和验证的损失变化plt.plot(history.history['loss'], label='loss')plt.plot(history.history['val_loss'], label = 'val_loss')plt.xlabel('Epoch')plt.ylabel('Loss')plt.legend(loc='upper right')plt.show()

http://www.ppmy.cn/news/1565604.html

相关文章

IOS 安全机制拦截 window.open

摘要 在ios环境,在某些情况下执行window.open不生效 一、window.open window.open(url, target, windowFeatures) 1. url:「可选参数」,表示你要加载的资源URL或路径,如果不传,则打开一个url地址为about:blank的空…

线上突发:MySQL 自增 ID 用完,怎么办?

线上突发:MySQL 自增 ID 用完,怎么办? 1. 问题背景2. 场景复现3. 自增id用完怎么办?4. 总结 1. 问题背景 最近,我们在数据库巡检的时候发现了一个问题:线上的地址表自增主键用的是int类型。随着业务越做越…

.Net Core微服务入门全纪录(四)——Ocelot-API网关(上)

系列文章目录 1、.Net Core微服务入门系列(一)——项目搭建 2、.Net Core微服务入门全纪录(二)——Consul-服务注册与发现(上) 3、.Net Core微服务入门全纪录(三)——Consul-服务注…

深入了解 Linux 的虚拟内存管理机制:Swap 机制

文章目录 深入了解 Linux 的 Swap 机制一、什么是 Swap?二、Swap 的工作原理三、Swap 的类型四、Swap 的使用场景五、配置 Swap六、Swap 的性能影响七、如何优化 Swap 使用八、总结 深入了解 Linux 的 Swap 机制 在 Linux 操作系统中,Swap 是一种虚拟内…

十一、apply家族(4)

tapply()函数 tapply()函数主要是用于对一个因子或因子列表,执行指定的函数调用,最后获得汇总信息。 tapply()函数的使用格式如下所示。 tapply(x, INDEX, FUN, ...&am…

centos 安全配置基线

CentOS 是一个广泛使用的操作系统,为了确保系统的安全性,需要遵循一系列的安全基线。以下是详细的 CentOS 安全基线配置建议: 通过配置核查,CentOS操作系统未安装入侵防护软件,无法检测到对重要节点进行入侵的 解决方案: 安装入侵…

【线性代数】基础版本的高斯消元法

[精确算法] 高斯消元法求线性方程组 线性方程组 考虑线性方程组, 已知 A ∈ R n , n , b ∈ R n A\in \mathbb{R}^{n,n},b\in \mathbb{R}^n A∈Rn,n,b∈Rn, 求未知 x ∈ R n x\in \mathbb{R}^n x∈Rn A 1 , 1 x 1 A 1 , 2 x 2 ⋯ A 1 , n x n b 1…

【jmeter】下载及使用教程【mac】

1.安装java 打开 Java 官方下载网站https://www.oracle.com/java/technologies/downloads/选择您想要下载的 Java 版本,下载以 .dmg 结尾的安装包,注意 JMeter 需要 Java 8下载后打开安装包点击“安装”按钮即可 2.下载jmeter 打开 Apache JMeter 官方…