EfficientNet-B0详解

news/2024/11/17 6:51:12/

文章转载来自:Dormineered

内容有些词汇翻译不准确,请见谅!!!个人整理不易,包含参数计算内容,更多训练阶段细节会在后期更新。转载请注明

EfficientNets是谷歌大脑的工程师谭明星和首席科学家Quoc V. Le在论文《EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks》中提出。该模型的基础网络架构是通过使用神经网络架构搜索(neural architecture search)设计得到。卷积神经网络模型通常是在已知硬件资源的条件下,进行训练的。当你拥有更好的硬件资源时,可以通过放大网络模型以获得更好的训练结果。为系统的研究模型缩放,谷歌大脑的研究人员针对EfficientNets的基础网络模型提出了一种全新的模型缩放方法,该方法使用简单而高效的复合系数来权衡网络深度、宽度和输入图片分辨率。

通过放大EfficientNets基础模型,获得了一系列EfficientNets模型。该系列模型在效率和准确性上战胜了之前所有的卷积神经网络模型。尤其是EfficientNet-B7在ImageNet数据集上得到了top-1准确率84.4%和top-5准确率97.1%的结果。且它和当时准确率最高的其它模型对比,大小缩小了8.4倍,效率提高了6.1倍。且通过迁移学习,EfficientNets在多个知名数据集上均达到了当时最先进的水平。

EfficientNets网络模型结构

本案例中我们选用EfficientNets系列中的基础网络模型EfficientNet-B0。 当该模型在ImageNet数据集上训练时,其一共包含5330564个参数,其中需要梯度下降来训练的参数有5288548个。不需要训练的参数是Batch Normalization层中的均值和方差共42016个。该网络的核心结构为移动翻转瓶颈卷积(mobile inverted bottleneck convolution,MBConv)模块,该模块还引入了压缩与激发网络(Squeeze-and-Excitation Network,SENet)的注意力思想,SENet在提出时也在ImageNet数据集上达到了当时最高的准确率。

移动翻转瓶颈卷积也是通过神经网络架构搜索得到的,该模块结构与深度分离卷积(depthwise separable convolution)相似,该移动翻转瓶颈卷积首先对输入进行1x1的逐点卷积并根据扩展比例(expand ratio)改变输出通道维度(如扩展比例为3时,会将通道维度提升3倍。但如果扩展比例为1,则直接省略该1x1的逐点卷积和其之后批归一化和激活函数)。接着进行kxk的深度卷积(depthwise convolution)。如果要引入压缩与激发操作,该操作会在深度卷积后进行。再以1x1的逐点卷积结尾恢复原通道维度。最后进行连接失活(drop connect)和输入的跳越连接(skip connection),这一做法源于论文《Deep networks with stochastic depth》,它让模型具有了随机的深度,剪短了模型训练所需的时间,提升了模型性能(注意,在EfficientNets中,只有当相同的移动翻转瓶颈卷积重复出现时,才会进行连接失活和输入的跳越连接,且还会将其中的深度卷积步长变为1),连接失活是一种类似于随机失活(dropout)的操作,并且在模块的开始和结束加入了恒等跳越。注意该模块中的每一个卷积操作后都会进行批归一化,激活函数使用的是Swish激活函数。

移动翻转瓶颈卷积模块中的压缩与激发操作,以下简称SE模块,是一种基于注意力的特征图操作操作,SE模块首先对特征图进行压缩操作,在通道维度方向上进行全局平均池化操作(global average pooling),得到特征图通道维度方向的全局特征。然后对全局特征进行激发操作,使用激活比例(R,该比例为浮点数)乘全局特征维数(C)个1x1的卷积对其进行卷积(原方法使用全连接层),学习各个通道间的关系,再通过sigmoid激活函数得到不同通道的权重,最后乘以原来的特征图得到最终特征。本质上,SE模块是在通道维度上做(注意力)attention或者(门控制)gating操作,这种注意力机制让模型可以更加关注信息量最大的通道特征,而抑制那些不重要的通道特征。另外一点是SE模块是通用的,这意味着其可以嵌入到现有的其它网络架构中。其结构如图1所示,注意在移动翻转瓶颈卷积模块中,与激活比例相乘的是移动翻转瓶颈卷积模块的输入通道维度,而不是模块中深度卷积后的输出通道维度。

在这里插入图片描述

图1 压缩与激发模块结构示意图

你已经了解了EfficientNet的核心模块,接下来,我们将进一步了解EfficientNet-B0的结构,它由16个移动翻转瓶颈卷积模块,2个卷积层,1个全局平均池化层和1个分类层构成。其结构如图2所示,图中不同的颜色代表了不同的阶段。
在这里插入图片描述

图2 EfficientNet-B0结构图

第一阶段,对输入的224x224x3的图像按顺序进行以下操作得到第一阶段的结果:

  1. 卷积(卷积核为32核3×3×3,步长为2×2,填充为“same”即输出的宽和高缩小一半),该卷积运算的输出是一个维度为(112×112×32)的特征图。因该层不含偏置项,故该层需要训练学习的参数共计864(32x3x3x3)个。

  2. 批归一化层(Batch Normalization,BN),该层输入为(112×112×32)的特征图,故该层含参数总数为128个(32x4),其中需要训练学习的参数为64个。

  3. Swish激活函数

第一阶段,总计参数128+864=992个,需要训练学习的参数928个。

第二阶段,对前一阶段输出的112x112x32的特征图进行移动翻转瓶颈卷积(扩张比例为1,深度卷积核大小为3x3,核步长为1x1,包含压缩与激发操作,无连接失活和连接跳越),并输出第二阶段的结果:

  1. 由于扩张比例为1,故跳过一开始的逐点卷积,直接进行深度卷积(卷积核为32核3×3×3,步长为1×1,填充为“same”即输出的宽和高不变)。深度卷积输出是一个维度为(112×112×32)的特征图。因该层不含偏置项,故该层需要训练学习的参数共计288(32x3x3x1)个。

  2. 批归一化层(Batch Normalization,BN),该层输入为(112×112×32)的特征图,故该层含参数总数为128个(32x4),其中需要训练学习的参数为64个。

  3. Swish激活函数。

  4. 全局平均池化层(global average pooling),该层在通道维度方向上进行全局平均池化,输出为(1x1x32)的特征图。

  5. 卷积(压缩与激发模块中的第一个卷积,卷积核为8核1x1x32,步长为1×1,填充为“same”即输出的宽和高不变),该卷积运算的输出是一个维度为(1×1×8)的特征图。因该层包含偏置项,故该层需要训练学习的参数共计264(8x1x1x32+8)个。

  6. Swish激活函数。

  7. 卷积(压缩与激发模块中的第二个卷积,卷积核为32核1x1x8,步长为1×1,填充为“same”即输出的宽和高不变),该卷积运算的的输出是一个维度为(1×1×32)的特征图。因该层包含偏置项,故该层需要训练学习的参数共计288(32x1x1x8+32)个。

  8. Sigmoid激活函数

  9. 与步骤3)的结果相乘,得到112x112x32的特征图。

  10. 逐点卷积(卷积核为16核1×1×32,步长为1×1,填充为“same”即输出的宽和高不变)该卷积运算的输出是一个维度为(112×112×16)的特征图。因该层不含偏置项,故该层需要训练学习的参数共计512(16x1x1x32)个。

  11. 批归一化层(Batch Normalization,BN),该层输入为(112×112×16)的特征图,故该层含参数总数为64个(16x4),其中需要训练学习的参数为32个。

第二阶段,总计参数288+128+264+288+512+64=1544个,需要训练学习的参数1448个。

第三阶段,对前一阶段输出的112x112x16的特征图进行两次移动翻转瓶颈卷积,第一个(扩张比例为6,深度卷积核大小为3x3,核步长为2x2,包含压缩与激发操作,无连接失活核连接跳越),第二个(扩张比例为6,深度卷积核大小为3x3,核步长为1x1,包含压缩与激发操作,有连接失活和连接跳越),并输出第二阶段的结果:

  1. 扩张比例为6的逐点卷积(卷积核为96核1×1×16,步长为1×1,填充为“same”即输出的宽和高不变)该卷积运算的输出是一个维度为(112×112×96)的特征图。因该层不含偏置项,故该层需要训练学习的参数共计1536(96x1x1x16)个。

  2. 批归一化层(Batch Normalization,BN),该层输入为(112×112×96)的特征图,故该层含参数总数为384个(96x4),其中需要训练学习的参数为192个。

  3. Swish激活函数

  4. 深度卷积(卷积核为6核3×3×16,步长为2×2,填充为“same”即输出的宽和高缩小一半)。深度卷积输出是一个维度为(56×56×96)的特征图。因该层不含偏置项,故该层需要训练学习的参数共计864(96x3x3x1)个。

  5. 批归一化层(Batch Normalization,BN),该层输入为(56×56×96)的特征图,故该层含参数总数为384个(96x4),其中需要训练学习的参数为192个。

  6. Swish激活函数。

  7. 全局平均池化层(global average pooling),该层在通道维度方向上进行全局平均池化,输出为(1x1x96)的特征图。

  8. 卷积(压缩与激发模块中的第一个卷积,卷积核为4核1x1x96,步长为1×1,填充为“same”即输出的宽和高不变),该卷积运算的输出是一个维度为(1×1×4)的特征图。因该层包含偏置项,故该层需要训练学习的参数共计388(4x1x1x96+4)个。

  9. Swish激活函数。

  10. 卷积(压缩与激发模块中的第二个卷积,卷积核为96核1x1x4,步长为1×1,填充为“same”即输出的宽和高不变),该卷积运算的的输出是一个维度为(1×1×96)的特征图。因该层包含偏置项,故该层需要训练学习的参数共计480(96x1x1x4+96)个。

  11. Sigmoid激活函数

  12. 与步骤6)的结果相乘,得到56x56x96的特征图。

  13. 逐点卷积(卷积核为24核1×1×96,步长为1×1,填充为“same”即输出的宽和高不变)该卷积运算的输出是一个维度为(56×56×24)的特征图。因该层不含偏置项,故该层需要训练学习的参数共计2304(24x1x1x96)个。

  14. 批归一化层(Batch Normalization,BN),该层输入为(56×56×24)的特征图,故该层含参数总数为96个(24x4),其中需要训练学习的参数为48个。

  15. 该阶段的第二个移动翻转瓶颈卷积(扩张比例为6,深度卷积核大小为3x3,核步长为1x1,包含压缩与激发操作,有连接失活和连接跳越),其中结尾的连接失活和连接跳越不含参数。除了深度卷积的步长发生了变化外,其余操作与第一个移动翻转瓶颈卷积相同,输出为(56x56x24)故第二个移动翻转瓶颈卷积的参数总和为3456+576+1296+576+870+1008+3456+96=11334,其中需要训练的参数为10701个。

第三阶段,总计参数17770个,需要训练学习的参数16705个。

第四阶段,对前一阶段输出的56x56x24的特征图进行两次移动翻转瓶颈卷积,第一个(扩张比例为6,深度卷积核大小为5x5,核步长为2x2,包含压缩与激发操作,无连接失活核连接跳越),第二个(扩张比例为6,深度卷积核大小为5x5,核步长为1x1,包含压缩与激发操作,有连接失活和连接跳越),输出是一个28x28x40的特征图。总计参数48336个,需要训练学习的参数46640个。

第五阶段,对前一阶段输出的28x28x40的特征图进行三次移动翻转瓶颈卷积,第一个(扩张比例为6,深度卷积核大小为3x3,核步长为2x2,包含压缩与激发操作,无连接失活核连接跳越),第二个(扩张比例为6,深度卷积核大小为3x3,核步长为1x1,包含压缩与激发操作,有连接失活核连接跳越),第三个(扩张比例为6,深度卷积核大小为3x3,核步长为1x1,包含压缩与激发操作,有连接失活核连接跳越),输出是一个14x14x80的特征图。总计参数248210个,需要训练学习的参数242930个。

第六阶段,对前一阶段输出的14x14x80的特征图进行三次移动翻转瓶颈卷积,第一个(扩张比例为6,深度卷积核大小为5x5,核步长为1x1,包含压缩与激发操作,无连接失活核连接跳越),第二个(扩张比例为6,深度卷积核大小为5x5,核步长为1x1,包含压缩与激发操作,有连接失活核连接跳越),第三个(扩张比例为6,深度卷积核大小为5x5,核步长为1x1,包含压缩与激发操作,有连接失活核连接跳越),输出是一个14x14x112的特征图。总计参数551116个,需要训练学习的参数543148个。

第七阶段,对前一阶段输出的14x14x112的特征图进行四次移动翻转瓶颈卷积,第一个(扩张比例为6,深度卷积核大小为5x5,核步长为2x2,包含压缩与激发操作,无连接失活核连接跳越),第二个(扩张比例为6,深度卷积核大小为5x5,核步长为1x1,包含压缩与激发操作,有连接失活核连接跳越),第三个(扩张比例为6,深度卷积核大小为5x5,核步长为1x1,包含压缩与激发操作,有连接失活核连接跳越),第四个(扩张比例为6,深度卷积核大小为5x5,核步长为1x1,包含压缩与激发操作,有连接失活核连接跳越),输出是一个7x7x192的特征图。总计参数2044396个,需要训练学习的参数2026348个。

第八阶段,对前一阶段输出的7x7x192的特征图进行一次移动翻转瓶颈卷积(扩张比例为6,深度卷积核大小为3x3,核步长为1x1,包含压缩与激发操作,无连接失活核连接跳越)输出是一个7x7x320的特征图。总计参数722480个,需要训练学习的参数717232个。

第九阶段,对输入的7x7x320的图像按顺序进行以下操作得到模型最终的记过:

  1. 卷积(卷积核为1280核1×1×320,步长为1×1,填充为“same”即输出的宽和高不变),该卷积运算的输出是一个维度为(7×7×1280)的特征图。因该层不含偏置项,故该层需要训练学习的参数共计409600(1280x1x1x320)个。

  2. 批归一化层(Batch Normalization,BN),该层输入为(7×7×1280)的特征图,故该层含参数总数为5120个(1280x4),其中需要训练学习的参数为2560个。

  3. Swish激活函数

  4. 全局平均池化层(global average pooling),该层在通道维度方向上进行全局平均池化,输出为(1x1x1280)的特征图。

  5. 随机失活dropout

  6. 全连接层,该层有1000个神经元。因该层包含偏置项,总参数个数为1281000(1000x1280+1000)

  7. Softmax激活函数,输出分类结果。

第九阶段,总计参数1695720个,需要训练学习的参数1693160个。

除了EfficientNet-B0外,EfficientNet系列还有其它7个网络(EfficientNet-B1,EfficientNet-B2,EfficientNet-B3,EfficientNet-B4,EfficientNet-B5,EfficientNet-B6,EfficientNet-B7),这些网络均是谷歌大脑团队通过神经网络架构搜索在不同的运算次数和运行内存限制下,在EfficientNet-B0的参数基础上对模型进行缩放得到的。主要涉及三个参数深度参数、广度参数和输入分辨率参数,通过这三个参数来控制模型的缩放。其中深度参数通过与EfficientNet-B0中各阶段的模块重复次数相乘,得到更深层的网络架构;广度系数通过与EfficientNet-B0中各卷积操作输入的核个数相乘,得到表现能力更强的网络模型;输入分辨率参数控制的则是网络的输入图片的长宽大小。


http://www.ppmy.cn/news/724823.html

相关文章

b0值

b0像是什么,b0像的去脑壳后的mask如何得到? b0像是DTI图像在扫面生成时b0是对应的图像,大家在拿到每个人的DTI数据时每个被试都有三个文件,bval文件,bvec文件和一个原始图像文件。这个图像文件可以用fslview打开,bval…

B0宏

在编译android平台用的ffmpeg时,抛出这样一个错误: 这句代码怎么看都找出有毛病,为什么B0会报错? 翻看aaccoder.c,也没有发现问题。为什么B0就成了一个常量数字,这里只有一个可能,B0在某处被宏定…

6-SIM数据交互之-B0(READ BINARY)

B0-READ BINARY B0即透明EF里面的内容binary(二进制),该指令一般在C0之后执行,在C0返回的fcp里面可以判断到该文件下是否存在binary及binary的长度,如果存在即可用B0需要读的字节长度。B0里面存的内容一般比较重要,如我们最常用的…

java中各种锁概念介绍,乐观锁 ,悲观锁 ,公平锁,非公平锁,可重入锁,读写锁,共享锁,自旋锁,偏向锁,轻量级锁,重量级锁等

乐观锁 乐观锁是一种乐观思想,即认为读多写少,遇到并发写的可能性低,每次去拿数据的时候都认为 别人不会修改,所以不会上锁,但是在更新的时候会判断一下在此期间别人有没有去更新这个数 据,采取在写时先读…

MySQL中的两种特殊插入方式

MySQL中的两种特殊插入方式 更新插入(on duplicate key update) 代码案例 PointMapper.java Mapper public interface PointMapper {/*** on duplicate key update ,是基于主键 或唯一索引 ,已存在数据则执行更新,不存在则执行插入*/int updateBatchByOdku(List…

【Leetcode】28. 找出字符串中第一个匹配项的下标

一、题目 1、题目描述 给你两个字符串 haystack 和 needle ,请你在 haystack 字符串中找出 needle 字符串的第一个匹配项的下标(下标从 0 开始)。如果 needle 不是 haystack 的一部分,则返回 -1 。 示例1: 输入:haystack = "sadbutsad", needle = "sa…

mysql binlog

简介 binlog用于记录数据库执行的写入性操作(不包括查询)信息,以二进制的形式保存在磁盘中。binlog是mysql的逻辑日志,并且由Server层进行记录,使用任何存储引擎的mysql数据库都会记录binlog日志。 binlog是通过追加的方式进行写入的&#…

日紫白飞星算法_年月日时紫白飞星法——紫白(入中)计算办法

年、月、日、时紫白飞星法 年飞星起例诀: 歌诀:年上吉星论甲子,逐年星逆中宫取, 上中下作三元汇,一上四中七下使。 又诀 上元甲子一白起, 中元四绿推甲子. 下元七赤兑位寻, 逐年星逆中宫是。 (逆数顺飞) 上元:(65-柱数)除9之余数.中元:(68-柱数)除9之余数.下元:(62-…