迁移学习CNN图像分类模型 - 花朵图片分类

news/2025/2/21 6:00:42/

训练一个好的卷积神经网络模型进行图像分类不仅需要计算资源还需要很长的时间。特别是模型比较复杂和数据量比较大的时候。普通的电脑动不动就需要训练几天的时间。为了能够快速地训练好自己的花朵图片分类器,我们可以使用别人已经训练好的模型参数,在此基础之上训练我们的模型。这个便属于迁移学习。本文提供训练数据集和代码下载。
这里写图片描述
这里写图片描述

原理:卷积神经网络模型总体上可以分为两部分,前面的卷积层和后面的全连接层。卷积层的作用是图片特征的提取,全连接层作用是特征的分类。我们的思路便是在inception-v3网络模型上,修改全连接层,保留卷积层。卷积层的参数使用的是别人已经训练好的,全连接层的参数需要我们初始化并使用我们自己的数据来训练和学习。

这里写图片描述

上面inception-v3模型图红色箭头前面部分是卷积层,后面是全连接层。我们需要修改修改全连接层,同时把模型的最终输出改为5。

由于这里使用了tensorflow框架,所以,我们需要获取上图红色箭头所在位置的张量BOTTLENECK_TENSOR_NAME(最后一个卷积层激活函数的输出值,个数为2048)以及模型最开始的输入数据的张量JPEG_DATA_TENSOR_NAME。获取这两个张量的作用是,图片训练数据通过JPEG_DATA_TENSOR_NAME张量输入模型,通过BOTTLENECK_TENSOR_NAME张量获取通过卷积层之后的图片特征。

BOTTLENECK_TENSOR_SIZE = 2048
BOTTLENECK_TENSOR_NAME = 'pool_3/_reshape:0'
JPEG_DATA_TENSOR_NAME = 'DecodeJpeg/contents:0'

通过下面的链接下载inception-v3模型,其中包含已经训练好的参数。

模型下载链接:地址

or https://pan.baidu.com/s/1LxBK5annrmiWSXE_jajOJQ

训练数据花朵图片下载:地址

通过下面的代码加载模型,同时获取上面所述的两个张量。

   # 读取已经训练好的Inception-v3模型。with gfile.FastGFile(os.path.join(MODEL_DIR, MODEL_FILE), 'rb') as f:graph_def = tf.GraphDef()graph_def.ParseFromString(f.read())bottleneck_tensor, jpeg_data_tensor = tf.import_graph_def(graph_def, return_elements=[BOTTLENECK_TENSOR_NAME, JPEG_DATA_TENSOR_NAME])

由于我们模型的功能是对五种花进行分类,所以,我们需要修改全连接层,这里,我们只增加一个全连接层。全连接层的输入数据便是BOTTLENECK_TENSOR_NAME张量。

	# 定义一层全链接层with tf.name_scope('final_training_ops'):weights = tf.Variable(tf.truncated_normal([BOTTLENECK_TENSOR_SIZE, n_classes], stddev=0.001))biases = tf.Variable(tf.zeros([n_classes]))logits = tf.matmul(bottleneck_input, weights) + biasesfinal_tensor = tf.nn.softmax(logits)

最后便是定义交叉熵损失函数。模型使用反向传播训练,而训练的参数并不是模型的所有参数,仅仅是全连接层的参数,卷积层的参数是不变的。

    # 定义交叉熵损失函数。cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=ground_truth_input)cross_entropy_mean = tf.reduce_mean(cross_entropy)train_step = tf.train.GradientDescentOptimizer(LEARNING_RATE).minimize(cross_entropy_mean)

那么接下来的是如何给我们的模型输入数据了,这里提供了几个操作数据的函数。由于训练数据集比较小,先把所有的图片通过JPEG_DATA_TENSOR_NAME张量输入模型,然后获取BOTTLENECK_TENSOR_NAME张量的值并保存到硬盘中。在模型训练的时候,从硬盘中读取所保存的BOTTLENECK_TENSOR_NAME张量的值作为全连接层的输入数据。因为一张图片可能会被使用多次。

# 输入图片并获取`BOTTLENECK_TENSOR_NAME`张量的值
def get_or_create_bottleneck(sess, image_lists, label_name, index, category, jpeg_data_tensor, bottleneck_tensor)# 从硬盘中读取`BOTTLENECK_TENSOR_NAME`张量的值,用于训练
def get_or_create_bottleneck(sess, image_lists, label_name, index, category, jpeg_data_tensor, bottleneck_tensor):# 从硬盘中读取`BOTTLENECK_TENSOR_NAME`张量的值,用于测试。
def get_test_bottlenecks(sess, image_lists, n_classes, jpeg_data_tensor, bottleneck_tensor)

不到5分钟就可以训练好我们的模型,精确度还蛮高的。下图是本人运行的结果。

这里写图片描述

源码地址:https://github.com/liangyihuai/my_tensorflow/tree/master/com/huai/converlution/transfer_learning


http://www.ppmy.cn/news/438757.html

相关文章

通过 Tensorflow 的基础类,构建卷积神经网络,用于花朵图片的分类

实验目的 通过 Tensorflow 的基础类,构建卷积神经网络,用于花朵图片的分类。 实验环境 import tensorflow as tfprint(tf.__version__)output: 2.3.0 实验步骤 (一) 数据获取和预处理 1.1 数据选择 TensorFlow 官方提供的花朵…

CNN实现花卉图片分类识别

CNN实现花卉图片分 前言 本文为一个利用卷积神经网络实现花卉分类的项目,因此不会过多介绍卷积神经网络的基本知识。此项目建立在了解卷积神经网络进行图像分类的原理上进行的。 项目简介 本项目为一个图像识别项目,基于tensorflow,利用C…

抓取花卉图片

对比用request抓取而言,使用selenium库会更简便抓取 话不多说现在开始: 首先我们要配置一下chromedriver: 1、chromedirver 下载网站https://registry.npmmirror.com/binary.html?pathchromedriver/下载与自己对应的谷歌版本 查看谷歌版本 如我自己…

调试厉器addr2line

addr2line: 将地址转换为文件名和行号的命令行工具 在C/C程序的调试过程中,我们通常会使用调试器(如GDB)来定位崩溃或错误的位置。但有时候,我们可能只能获得程序崩溃时的地址,而没有调试器的支持。这时候&#xff0c…

如保查看wifi无线的mac地址

使用命令行,运行ipconfig /all 前提是保证无线网卡未被禁用。 找到无线局域网的物理地址。 以太网的特理地址,是网卡的mac地址。

查看wifi连接路由器的MAC地址

windows连接wifi ,通过cmd运行如下命令,查看 netsh wlan show networks modebssid

更改WLAN的IP地址

网络【右键】–>打开“网络和internet”设置【左键】–>高级网络设置–>更改适配器选项【左键】–>WLAN【右键】–>属性【左键】–>internet协议版本4(TCP/IPV4)【左键双击】–>更改IP地址和DNS服务器–>【确定】

蓝牙MAC地址认证以及WiFi MAC地址认证

有没有想过,手机,或者蓝牙耳机,蓝牙音响,产品需要链接蓝牙的时候,是通过怎么样的一个方法来识别那个产品对应的是那个蓝牙呢, 有没有想过,当你手机打开WIFI ,想要去链接WIFI的时候,有…