Python从0到100(八十五):神经网络-使用迁移学习完成猫狗分类

embedded/2025/1/24 1:24:52/

在这里插入图片描述

前言: 零基础学Python:Python从0到100最新最全教程。 想做这件事情很久了,这次我更新了自己所写过的所有博客,汇集成了Python从0到100,共一百节课,帮助大家一个月时间里从零基础到学习Python基础语法、Python爬虫、Web开发、 计算机视觉、机器学习、神经网络以及人工智能相关知识,成为学习学习和学业的先行者!
欢迎大家订阅专栏:零基础学Python:Python从0到100最新最全教程!

今天来学习一下如何使用基于tensorflow和keras迁移学习完成猫狗分类,欢迎大家一起前来探讨学习~

说明:在此试验下,我们使用的是使用tf2.x版本,在jupyter环境下完成
在本文中,我们将主要完成以下任务:

  1. 实现基于tensorflow和keras的迁移学习

  2. 加载tensorflow提供的数据集(不得使用cifar10)

  3. 需要使用markdown单元格对数据集进行说明

  4. 加载tensorflow提供的预训练模型(不得使用vgg16)

  5. 需要使用markdown单元格对原始模型进行说明

  6. 网络末端连接任意结构的输出端网络

  7. 用图表显示准确率和损失函数

  8. 用cnn工具可视化一批数据的预测结果

  9. 用cnn工具可视化一个数据样本的各层输出

一、加载数据集

1.调用库函数

python">import matplotlib.pyplot as plt
import numpy as np
import os
import tensorflow as tf
import cnn_utils
from tensorflow.keras.preprocessing import image_dataset_from_directory
from tensorflow.keras.layers import GlobalAveragePooling2D,Dense,Input,Dropout

2.加载数据集

数据集加载,数据是通过这个网站下载的猫狗数据集:http://aimaksen.bslience.cn/cats_and_dogs_filtered.zip,实验中为了训练方便,我们取了一个较小的数据集。

python">path_to_zip = tf.keras.utils.get_file('data.zip',origin='http://aimaksen.bslience.cn/cats_and_dogs_filtered.zip',extract=True,
)
PATH = os.path.join(os.path.dirname(path_to_zip), 'cats_and_dogs_filtered')train_dir = os.path.join(PATH, 'train')
validation_dir = os.path.join(PATH, 'validation')BATCH_SIZE = 32
IMG_SIZE = (160, 160)

3.数据集管理

使用image_dataset_from_director进行数据集管理,使用ImageDataGenerator训练过程中会出现错误,不知道是什么原因,就使用了原始的image_dataset_from_director方法进行数据集管理。

python">train_dataset = image_dataset_from_directory(train_dir,shuffle=True,batch_size=BATCH_SIZE,image_size=IMG_SIZE)validation_dataset = image_dataset_from_directory(validation_dir,shuffle=True,batch_size=BATCH_SIZE,image_size=IMG_SIZE)

二、猫狗数据集介绍

1.猫狗数据集介绍:

猫狗数据集包括25000张训练图片,12500张测试图片,包括猫和狗两种图片。在此次实验中为了训练方便,我们取了一个较小的数据集。 数据解压之后会有两个文件夹,一个是 “train”,一个是 “test”,顾名思义一个是用来训练的,另一个是作为检验正确性的数据。
在这里插入图片描述
在train文件夹里边是一些已经命名好的图像,有猫也有狗。而在test文件夹中是只有编号名的图像。
在这里插入图片描述

2.图片展示

下面是数据集中的图片展示:

python">class_names = ['cats', 'dogs']plt.figure(figsize=(10, 10))
for images, labels in train_dataset.take(1):for i in range(9):ax = plt.subplot(3, 3, i + 1)plt.imshow(images[i].numpy().astype("uint8"))plt.title(class_names[labels[i]])plt.axis("off")

🌟🌟🌟 这里是输出的结果:✨✨✨
在这里插入图片描述

三、MobileNetV2网络介绍

1.加载tensorflow提供的预训练模型

python">val_batches = tf.data.experimental.cardinality(validation_dataset)
test_dataset = validation_dataset.take(val_batches // 5)
validation_dataset = validation_dataset.skip(val_batches // 5)

2.轻量级网络——MobileNetV2

使用轻量级网络——MobileNetV2进行数据预处理 说明: MobileNetV2是基于倒置的残差结构,普通的残差结构是先经过 1x1 的卷积核把 feature map的通道数压下来,然后经过 3x3 的卷积核,最后再用 1x1 的卷积核将通道数扩张回去,即先压缩后扩张,而MobileNetV2的倒置残差结构是先扩张后压缩
在这里插入图片描述

3.MobileNetV2的网络模块

MobileNetV2的网络模块样子是这样的:
在这里插入图片描述
MobileNetV2是基于深度级可分离卷积构建的网络,它是将标准卷积拆分为了两个操作:深度卷积 和 逐点卷积,深度卷积和标准卷积不同,对于标准卷积其卷积核是用在所有的输入通道上,而深度卷积针对每个输入通道采用不同的卷积核,就是说一个卷积核对应一个输入通道,所以说深度卷积是depth级别的操作。而逐点卷积其实就是普通的卷积,只不过其采用1x1的卷积核。
MobileNetV2的模型如下图所示,其中t为Bottleneck内部升维的倍数,c为通道数,n为该bottleneck重复的次数,s为sride
在这里插入图片描述

其中,当stride=1时,才会使用elementwise 的sum将输入和输出特征连接(如下图左侧);stride=2时,无short cut连接输入和输出特征(下图右侧):
在这里插入图片描述

四、搭建迁移学习

1.训练

python">inital_input = tf.keras.applications.mobilenet_v2.preprocess_input
python">IMG_SHAPE = IMG_SIZE + (3,)
base_model = tf.keras.applications.MobileNetV2(input_shape=IMG_SHAPE,include_top=False,weights='imagenet')
python">base_model.trainable = False
base_model.summary()

🌟🌟🌟 这里是输出的结果:✨✨✨
在这里插入图片描述

2.训练结果可视化

用图表显示准确率和损失函数

python"># 训练结果可视化,用图表显示准确率和损失函数
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']loss = history.history['loss']
val_loss = history.history['val_loss']epochs_range=range(initial_epochs)
plt.figure(figsize=(8,8))
plt.subplot(2,1,1)
plt.plot(epochs_range, acc, label="Training Accuracy")
plt.plot(epochs_range, val_acc,label="Validation Accuracy")
plt.legend()
plt.title("Training and Validation Accuracy")plt.subplot(2,1,2)
plt.plot(epochs_range, loss, label="Training Loss")
plt.plot(epochs_range, val_loss,label="Validation Loss")
plt.legend()
plt.title("Training and Validation Loss")
plt.show()

🌟🌟🌟 这里是输出的结果:✨✨✨
在这里插入图片描述

3.输出训练的准确率

python"># 输出训练的准确率
test_loss, test_accuracy = model.evaluate(test_dataset)
print('test accuracy: {:.2f}'.format(test_accuracy))

🌟🌟🌟 这里是输出的结果:✨✨✨
在这里插入图片描述

4.用cnn工具可视化一批数据的预测结果

python">label_dict = {0: 'cat',1: 'dog'
}test_image_batch, test_label_batch = test_dataset.as_numpy_iterator().next()
# 编码成uint8 以图片形式输出
test_image_batch = test_image_batch.astype('uint8')cnn_utils.plot_predictions(model, test_image_batch, test_label_batch, label_dict, 32, 5, 5)

🌟🌟🌟 这里是输出的结果:✨✨✨
在这里插入图片描述

5.数据输出

python"># 数据输出,数字化特征图
test_image_batch, test_label_batch = train_dataset.as_numpy_iterator().next()img_idx = 0
random_batch = np.random.permutation(np.arange(0,len(test_image_batch)))[:BATCH_SIZE]
image_activation = test_image_batch[random_batch[img_idx]:random_batch[img_idx]+1]cnn_utils.get_activations(base_model, image_activation[0])

🌟🌟🌟 这里是输出的结果:✨✨✨
在这里插入图片描述

6.用cnn工具可视化一个数据样本的各层输出

python">cnn_utils.display_activations(cnn_utils.get_activations(base_model, image_activation[0]))

🌟🌟🌟 这里是输出的结果:✨✨✨
在这里插入图片描述

7.输出结果图像

🌟🌟🌟 这里是输出的结果:✨✨✨
在这里插入图片描述
在这里插入图片描述

在这里插入图片描述

文末送书

本期推荐1:

《Java面向对象程序设计:AI大模型给程序员插上翅膀》
AI工具助力Java编程:故事引领思政,AI助力学习;任务驱动实践,项目提升能力。
在这里插入图片描述

京东:https://item.jd.com/14850722.html

从AI助力角度出发,轻松学习编程
故事引入思政,引发读者动手实践
引出目标任务,明确学习目的和方向
AI学习问答与同步训练,提升学习效率
丰富的学习资源,助力实际项目开发
内容简介
随着云计算、物联网、大数据、人工智能等新一代信息技术的发展,Java 作为一种高性能、跨平台的编程语言,有着广泛的应用。本书从应用的角度详尽介绍了Java开发的核心技术。
全书分为12章,主要介绍了Java开发环境、Java编程基础、类和对象、继承和多态、抽象类和接口、Java常用类、内部类和泛型、集合容器、JDBC编程、图形用户界面设计、多线程,最后通过企业项目管理的方式进行实践,实现一个完整案例。
本书每章都通过故事的方式引入思政,并且从故事中引出目标任务。针对目标任务,辅以人工智能工具(ChatGPT、文心一言、讯飞星火)的帮助,得到行之有效的示例。之后对其进行知识解析,并完成上机练习。通过相关的练习巩固知识,并在合适的阶段引入一些常见的算法,加强学生的逻辑思维能力。在每章末尾有AI学习问答,让读者自行探索,同时加入同步训练,加强学习效果。

本期推荐2:

《Python金融大数据分析》
掌握Python,从零到一速成金融分析高手!实战案例深剖,让数字说话,让决策更精准!深入了解金融数据分析的具体过程和方法,提高实操能力。附赠书中案例源代码。
在这里插入图片描述

京东:https://item.jd.com/14827368.html

系统:全面构建Python金融大数据分析框架,从零到一,系统掌握核心技能,让学习之路有条不紊。
经典:凝聚笔者多年智慧,解读大数据在金融领域的应用,确保学习内容前沿且可靠。
深入:深度剖析Python在金融大数据分析中的关键技术,直击核心难点,助您深入理解数据背后的价值。
案例:精选实战案例,让您在真实场景中磨炼技能,实现从理论到实践的完美跨越。
内容简介
本书共分为11 章,全面介绍了以Python为工具的金融大数据的理论和实践,特别是量化投资和交易领域的相关应用,并配有项目实战案例。书中涵盖的内容主要有Python概览,结合金融场景演示Python的基本操作,金融数据的获取及实战,MySQL数据库详解及应用,Python在金融大数据分析方面的核心模块详解,金融分析及量化投资,Python量化交易,数据可视化Matplotlib,基于NumPy的股价统计分析实战,基于Matplotlib的股票技术分析实战,以及量化交易策略实战案例等。
本书内容通俗易懂,案例丰富,实用性强,特别适合以下人群阅读:金融行业的从业者、数据分析师、量化投资者、希望提高数据分析能力的投资者,以及对大数据分析感兴趣的编程人员。另外,本书也适合作为相关培训机构的教材。


http://www.ppmy.cn/embedded/156446.html

相关文章

1. 基于图像的三维重建

1. 基于图像的三维重建 核心概念三维重建中深度图、点云的区别?深度图点云总结 深度图到点云还需要什么步骤?1. **获取相机内参**2. **生成相应的像素坐标**3. **计算三维坐标**4. **构建点云**5. **处理颜色信息(可选)**6. **去除…

centos哪个版本建站好?centos最稳定好用的版本

在信息化飞速发展的今天,服务器操作系统作为构建网络架构的基石,其稳定性和易用性成为企业和个人用户关注的重点。CentOS作为一款广受欢迎的开源服务器操作系统,凭借其强大的性能、出色的稳定性和丰富的软件包资源,成为众多用户建…

Windows7搭建Hadoop-2.7.3源码阅读环境问题解决列表

个人博客地址:Window7搭建Hadoop-2.7.3源码阅读环境问题解决列表 | 一张假钞的真实世界 环境说明 Windows 7java version “1.7.0_80”Apache Maven 3.2.3ProtocolBuffer 2.5.0cmake version 3.7.2 win64 x64Windows SDK 7.1构建过程参照源代码目录下BUILDING.txt说明文件中的…

YOLOv8改进,YOLOv8检测头融合DSConv(动态蛇形卷积),并添加小目标检测层(四头检测),适合目标检测、分割等

精确分割拓扑管状结构例如血管和道路,对各个领域至关重要,可确保下游任务的准确性和效率。然而,许多因素使任务变得复杂,包括细小脆弱的局部结构和复杂多变的全局形态。在这项工作中,注意到管状结构的特殊特征,并利用这一知识来引导 DSCNet 在三个阶段同时增强感知:特征…

Vue3+Element Plus 实现 el-table 表格组件滚动是否触底监听判断

问题描述 Element Plus 中的 el-table 组件暴露出了 scroll 事件,表格被用户滚动后会触发,暴露出横向和竖向的滚动距离,未暴露出表格的DOM对象。 ({ scrollLeft: number, scrollTop: number }) > void此时,可以通过表格的引用…

摄影交流平台项目Uniapp+Springboot已完成

后端项目结构 前端项目结构 前端效果

深度学习python基础(第三节) 函数、列表

本节主要介绍函数、列表的基本语法格式。 函数 与c语言的函数差不多,就是语法基本格式不同。 name "loveyou" length len(name) print("字符串的长度为:%d" % length) # 自定义函数 def countstr(data):count 0for i in da…

Windows 通过 openssh 连接 Ubuntu 24.04 LTS

Ubuntu 24.04 LTS Ubuntu 配置 sudo apt update sudo apt install openssh-server sudo systemctl start ssh sudo systemctl enable ssh sudo systemctl status ssh sudo ufw status sudo ufw allow ssh sudo ufw reload sudo ufw status安装 OpenSSH 服务器 首先&#xff…