Python从0到100(八十四):神经网络-卷积神经网络训练CIFAR-10数据集

ops/2025/1/19 15:13:23/

在这里插入图片描述

前言: 零基础学Python:Python从0到100最新最全教程。 想做这件事情很久了,这次我更新了自己所写过的所有博客,汇集成了Python从0到100,共一百节课,帮助大家一个月时间里从零基础到学习Python基础语法、Python爬虫、Web开发、 计算机视觉、机器学习、神经网络以及人工智能相关知识,成为学习学习和学业的先行者!
欢迎大家订阅专栏:零基础学Python:Python从0到100最新最全教程!

1.数据集介绍

CIFAR-10 数据集由 10 个类的 60000 张 32x32 彩色图像组成,每类 6000 张图像。有 50000 张训练图像和 10000 张测试图像。
数据集分为5个训练批次和1个测试批次,每个批次有10000张图像。测试批次正好包含从每个类中随机选择的 1000 张图像。训练批次以随机顺序包含剩余的图像,但某些训练批次可能包含来自一个类的图像多于另一个类的图像。在它们之间,训练批次正好包含来自每个类的 5000 张图像。

总结:

Size(大小): 32×32 RGB图像 ,数据集本身是 BGR 通道
Num(数量): 训练集 50000 和 测试集 10000,一共60000张图片
Classes(十种类别): plane(飞机), car(汽车),bird(鸟),cat(猫),deer(鹿),dog(狗),frog(蛙类),horse(马),ship(船),truck(卡车)
在这里插入图片描述

下载链接

来自博主(Dream是个帅哥)的分享:
链接: https://pan.baidu.com/s/1gKazlkk108V_1nrc68VoSQ 提取码: 0213

数据集文件夹

在这里插入图片描述

CIFAR-100数据集(拓展)

这个数据集与CIFAR-10类似,只不过它有100个类,每个类包含600个图像。每个类有500个训练图像和100个测试图像。CIFAR-100中的100个子类被分为20个大类。每个图像都有一个“fine”标签(它所属的子类)和一个“coarse”标签(它所属的大类)。

CIFAR-10数据集与MNIST数据集对比

  • 维度不同:CIFAR-10数据集有4个维度,MNIST数据集有3个维度(CIRAR-10的四维: 一次的样本数量, 图片高, 图片宽, 图通道数 -> N H W C;MNIST的三维: 一次的样本数量, 图片高, 图片宽 -> N H W)
  • 图像类型不同:CIFAR-10数据集是RGB图像(有三个通道),MNIST数据集是灰度图像,这也是为什么CIFAR-10数据集比MNIST数据集多出一个维度的原因。
  • 图像内容不同:CIFAR-10数据集展示的是各种不同的物体(猫、狗、飞机、汽车…),MNIST数据集展示的是不同人的手写0~9数字。

2.数据集读取

读取数据集

选取data_batch_1可视化其中一张图:

python">def unpickle(file):import picklewith open(file, 'rb') as fo:dict = pickle.load(fo, encoding='bytes')return dict
dict = unpickle('D:\PycharmProjects\model-fuxian\CIFAR\cifar-10-batches-py\data_batch_1')
print(dict)

输出结果:
一批次的数据集中有4个字典键,我们需要用到的就是 数据标签 和 数据内容(10000×32×32×3,10000张32×32大小为rgb三通道的图片)
在这里插入图片描述
输出的是一个字典:

{
b’batch_label’: b’training batch 1 of 5’,
b’labels’: [6, 9 … 1,5],
b’data’: array([[ 59, 43, …, 84, 72],…[ 62, 61, 60, …, 130, 130, 131]], dtype=uint8),
b’filenames’: [b’leptodactylus_pentadactylus_s_000004.png’,…b’cur_s_000170.png’]

}

其中,各个代表的意思如下:
b’batch_label’ : 所属文件集
b’labels’ : 图片标签
b’data’ :图片数据
b’filename’ :图片名称

读取类型

python">print(type(dict[b'batch_label']))
print(type(dict[b'labels']))
print(type(dict[b'data']))
print(type(dict[b'filenames']))

输出结果:

<class ‘bytes’>
<class ‘list’>
<class ‘numpy.ndarray’>
<class ‘list’>

读取图片

python">img = dict[b'data']
print(img.shape)

输出结果:(10000, 3072),其中 3072 = 32 * 32 * 3 (图片 size)

3.数据集调用

TensorFlow 调用

python">from tensorflow.keras.datasets import cifar10(x_train,y_train), (x_test, y_test) = cifar10.load_data()

本地调用

python">def unpickle(file):import picklewith open(file, 'rb') as fo:dict = pickle.load(fo, encoding='bytes')return dict
dict = unpickle('D:\PycharmProjects\model-fuxian\CIFAR\cifar-10-batches-py\data_batch_1')

4.卷积神经网络训练

此处参考:传送门

1.指定GPU

python">gpus = tf.config.experimental.list_physical_devices('GPU')
tf.config.experimental.set_memory_growth(gpus[0],True)
#初始化
plt.rcParams['font.sans-serif'] = ['SimHei']

2.加载数据

python">cifar10 = tf.keras.datasets.cifar10
(train_x,train_y),(test_x,test_y) = cifar10.load_data()
print('\n train_x:%s, train_y:%s, test_x:%s, test_y:%s'%(train_x.shape,train_y.shape,test_x.shape,test_y.shape))

3.数据预处理

python">X_train,X_test = tf.cast(train_x/255.0,tf.float32),tf.cast(test_x/255.0,tf.float32)     #归一化
y_train,y_test = tf.cast(train_y,tf.int16),tf.cast(test_y,tf.int16)

4.建立模型

adam算法参数采用keras默认的公开参数,损失函数采用稀疏交叉熵损失函数,准确率采用稀疏分类准确率函数

python">model = tf.keras.Sequential()
##特征提取阶段
#第一层
model.add(tf.keras.layers.Conv2D(16,kernel_size=(3,3),padding='same',activation=tf.nn.relu,data_format='channels_last',input_shape=X_train.shape[1:]))  #卷积层,16个卷积核,大小(3,3),保持原图像大小,relu激活函数,输入形状(28,28,1)
model.add(tf.keras.layers.Conv2D(16,kernel_size=(3,3),padding='same',activation=tf.nn.relu))
model.add(tf.keras.layers.MaxPool2D(pool_size=(2,2)))   #池化层,最大值池化,卷积核(2,2)
#第二层
model.add(tf.keras.layers.Conv2D(32,kernel_size=(3,3),padding='same',activation=tf.nn.relu))
model.add(tf.keras.layers.Conv2D(32,kernel_size=(3,3),padding='same',activation=tf.nn.relu))
model.add(tf.keras.layers.MaxPool2D(pool_size=(2,2)))
##分类识别阶段
#第三层
model.add(tf.keras.layers.Flatten())    #改变输入形状
#第四层
model.add(tf.keras.layers.Dense(128,activation='relu'))     #全连接网络层,128个神经元,relu激活函数
model.add(tf.keras.layers.Dense(10,activation='softmax'))   #输出层,10个节点
print(model.summary())      #查看网络结构和参数信息#配置模型训练方法
model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['sparse_categorical_accuracy'])

5.训练模型

批量训练大小为64,迭代5次,测试集比例0.2(48000条训练集数据,12000条测试集数据)

python">history = model.fit(X_train,y_train,batch_size=64,epochs=5,validation_split=0.2)

6.评估模型

python">model.evaluate(X_test,y_test,verbose=2)     #每次迭代输出一条记录,来评价该模型是否有比较好的泛化能力#保存整个模型
model.save('CIFAR10_CNN_weights.h5')

7.结果可视化

python">print(history.history)
loss = history.history['loss']          #训练集损失
val_loss = history.history['val_loss']  #测试集损失
acc = history.history['sparse_categorical_accuracy']            #训练集准确率
val_acc = history.history['val_sparse_categorical_accuracy']    #测试集准确率plt.figure(figsize=(10,3))plt.subplot(121)
plt.plot(loss,color='b',label='train')
plt.plot(val_loss,color='r',label='test')
plt.ylabel('loss')
plt.legend()plt.subplot(122)
plt.plot(acc,color='b',label='train')
plt.plot(val_acc,color='r',label='test')
plt.ylabel('Accuracy')
plt.legend()

8.使用模型

python">plt.figure()
for i in range(10):num = np.random.randint(1,10000)plt.subplot(2,5,i+1)plt.axis('off')plt.imshow(test_x[num],cmap='gray')demo = tf.reshape(X_test[num],(1,32,32,3))y_pred = np.argmax(model.predict(demo))plt.title('标签值:'+str(test_y[num])+'\n预测值:'+str(y_pred))
plt.show()

输出结果:
在这里插入图片描述
在这里插入图片描述
上面的内容分别是训练样本的损失函数值和准确率、测试样本的损失函数值和准确率,可以看到它每次训练迭代时损失函数和准确率的变化,从最后一次迭代结果上看,测试样本的损失函数值达到0.9123,准确率仅达到0.6839。
这个结果并不是很好,我尝试过增加迭代次数,发现训练样本的损失函数值可以达到0.04,准确率达到0.98;但实际上训练模型却产生了越来越大的泛化误差,这就是训练过度的现象,经过尝试泛化能力最好时是在迭代第5次的状态,故只能选择迭代5次。
在这里插入图片描述

训练好的模型文件——直接用

CIFAR10数据集介绍,并使用卷积神经网络训练图像分类模型——附完整代码训练好的模型文件——直接用:https://download.csdn.net/download/weixin_51390582/88788820

文末送书

本期推荐1:

《MATLAB数学建模从入门到精通》
从理论框架到实战案例,跨平台通用性强,从模型构建到应用实践,由浅入深掌握数学建模之道,灵活应对复杂问题挑战。
在这里插入图片描述

京东:https://item.jd.com/14827364.html

源于实际:本书全面归纳和整理笔者多年的数学建模教学实践经验,体现了来源于实际服务于实际的原则。
由浅入深:从基础知识开始逐步介绍数学建模的相关知识,学习门槛很低。
通俗易懂:本书力争让晦涩的知识变得通俗易懂。
内容实用:结合大量实例进行讲解,能够有效指导数学建模新手入门。
内容简介
本书结合案例,系统介绍了使用 MATLAB 进行数学建模的相关知识和方法论。
本书分为 11 章,主要包括走进数学建模的世界、函数极值与规划模型、微分方程与差分模型、数据处理的基本策略、权重生成与评价模型、复杂网络与图论模型、时间序列与投资模型、机器学习与统计模型、进化计算与群体智能、其他数学建模知识、数学建模竞赛中的一些基本能力。

本期推荐2:

《ChatGPT时代:GPTs开发详解》
解锁未来科技,GPTs开发应用一本通:从智能对话到无限创意,您的AI创新之旅,由此启程!
在这里插入图片描述

京东:https://item.jd.com/14899622.html

案例详解GPTs技术,助你实现创新。
零基础搭建GPTs,打造个性化助手。
学会快速定制助手,轻松为己所用。
详解学习辅助类GPTs、个人成长类GPTs、办公助手类GPTs、电商和社交媒体类GPTs、短视频制作类GPTs五大类GPTs。
掌握各类助手的搭建和发布,全面提升工作和学习效率,并获得平台收益。
内容简介
定制化GPTs(Custom GPTs)是由OpenAI推出的一种创新技术,它允许用户根据自己的特定需求和应用场景来创建定制版本的GPTs。定制化GPTs结合了用户自定义的指令、额外的专业知识以及多样化的技能,旨在为用户提供日常生活、工作或特定任务中的更多帮助和支持。
本书详细介绍了GPTs的概念、功能特点、创建第一个GPTs、指令使用技巧、知识库使用技巧,以及如何根据个人需求和应用场景创建定制化GPTs,如学习辅助类GPTs、个人成长类GPTs、办公助手类GPTs、电商和社交媒体类GPTs、短视频制作类GPTs等。
本书内容充实,语言通俗易懂,案例丰富,具有很强的可读性。它既适合初次接触AI技术的普通读者阅读,也适合有一定经验的AI从业者借鉴。此外,本书也适合那些需要了解最新ChatGPT技术动态的开发人员阅读。


http://www.ppmy.cn/ops/151410.html

相关文章

QT笔记- Qt6.8.1安卓开发配置

1. Qt添加组件Android, 注意登录时需连接VPN, 选好组件下载时可退出VPN.2. 手动下载安装java的jdk&#xff0c;版本jdk20及以下, 注意, 即便安装最新版本的Android SDK一般也不支持高版本的jdk.3. 安卓SDK, NDK, OpenSSL在Qt设置-设备-Android中按Qt提供的方式安装.4. 环境变量…

2025.1.16——四、get_post 传参方式

题目来源&#xff1a;攻防世界get_post 目录 一、打开靶机&#xff0c;分析信息 ​编辑 二、解题步骤 step 1&#xff1a;GET方式传参 step 2&#xff1a;POST方式传参——hackbar插件&#xff08;火狐浏览器&#xff09; 三、小结 一、打开靶机&#xff0c;分析信息 这…

Ettercap 入门

syntax ettercap [options] [target1] [target2] 没有目标地址或是源地址之分, 因为通信是双向的 如果接口支持IPv6&#xff0c;target 以这样的形式 MAC/IPs/IPv6/Ports否则以MAC/IPs/Ports的形式。 /10.0.0.1/表示所有mac地址&#xff0c;所有端口&#xff0c;只要ip地址是…

Linux 音视频入门到实战专栏(视频篇)视频编解码 MPP

文章目录 一、MPP 介绍二、获取和编译RKMPP库三、视频解码四、视频编码 沉淀、分享、成长&#xff0c;让自己和他人都能有所收获&#xff01;&#x1f604; &#x1f4e2;本篇将介绍如何调用alsa api来进行音频数据的播放和录制。 一、MPP 介绍 瑞芯微提供的媒体处理软件平台…

Redis 缓存穿透、击穿、雪崩 的区别与解决方案

前言 Redis 是一个高性能的键值数据库&#xff0c;广泛应用于缓存、会话存储、实时数据分析等场景。然而&#xff0c;在高并发的环境下&#xff0c;Redis 缓存可能会遇到 缓存击穿、缓存穿透 和 缓存雪崩 这三大问题。这些问题不仅影响系统的稳定性和性能&#xff0c;还经常出…

精准测量,尽在掌握 —— 电导率传感器:科技之水质的守护者

在科技日新月异的今天&#xff0c;我们身边的每一个细节都融入了智能与精准的元素。水质监测&#xff0c;这一关乎人类健康与生态环境的重大课题&#xff0c;同样受益于现代科技的进步。其中&#xff0c;电导率传感器作为水质监测的重要工具&#xff0c;正以它独特的方式&#…

MySQL 篇 - Java 连接 MySQL 数据库并实现数据交互

在现代应用中&#xff0c;数据库是不可或缺的一部分。Java 作为一种广泛使用的编程语言&#xff0c;提供了丰富的 API 来与各种数据库进行交互。本文将详细介绍如何在 Java 中连接 MySQL 数据库&#xff0c;并实现基本的数据交互功能。 一、环境准备 1.1 安装 MySQL 首先&am…

定时任务特辑 Quartz、xxl-job、elastic-job、Cron四个定时任务框架对比,和Spring Boot集成实战

专栏集锦&#xff0c;大佬们可以收藏以备不时之需&#xff1a; Spring Cloud 专栏&#xff1a;http://t.csdnimg.cn/WDmJ9 Python 专栏&#xff1a;http://t.csdnimg.cn/hMwPR Redis 专栏&#xff1a;http://t.csdnimg.cn/Qq0Xc TensorFlow 专栏&#xff1a;http://t.csdni…