Faster RCNN网络数据流总结

news/2024/11/29 3:40:11/

前言

在学习Faster RCNN时,看了许多别人写的博客。看了以后,对Faster RCNN整理有了一个大概的了解,但是对训练时网络内部的数据流还不是很清楚,所以在结合这个版本的faster rcnn代码情况下,对网络数据流进行总结。以便自己更好地掌握Faster rcnn。

训练时的数据流

在这个版本的代码中,训练时的batch_size为1。原论文中的网络架构如下所示:
在这里插入图片描述

1 ◯ \textcircled{\scriptsize 1} 1 网络输入

第一部分是网络的输入。网络的输入是一个任意大小的图像,但是在被送入网络之前,会经过一个缩放操作,然后进行normalize。对图像进行缩放的同时,也要对gt_bbox(ground truth bounding box,真实边界框)进行同样的缩放。
具体是怎么缩放的呢?参考这里的代码。

def preprocess(img, min_size=600, max_size=1000):# img: 输入图像# min_size: 图像放缩的最小大小# max_size: 图像放缩的最大大小C, H, W = img.shapescale1 = min_size / min(H, W)scale2 = max_size / max(H, W)scale = min(scale1, scale2)img = img / 255.# resize缩放大小  长和宽等比例缩放img = sktsf.resize(img, (C, H * scale, W * scale), mode='reflect',anti_aliasing=False)

这样的等比例缩放方式,结果就是要么原图较长的边被放大为1000,要么原图较短的边被放大为600。整体上来看是设定了一个放大后的最大最小范围。因为batch_size为1,所以每一张图像缩放后的大小可以不一样,如果batch_size不为1,那么这一个batch内的所有图像缩放后的大小就必须一样。在接下来的讨论中我们忽略batch维度(因为batch是1)

2 ◯ \textcircled{\color{green}\scriptsize 2} 2 特征提取网络

第二部分是特征提取模块。这里的特征提取网络是VGG16,只不过去掉了最后的几层全连接。这里感觉唯一要注意的地方就是,输入图像经过VGG16,大小缩小了16倍(因为有4个池化层),维度增加到了512维度。
如果输入图像 I i n p u t I^{input} Iinput的大小是 [ 3 , x , y ] \left[3,x,y\right] [3,x,y],那么经过特征提取的特征图 I f e a t u r e I^{feature} Ifeature的大小是 [ 512 , x 16 , y 16 ] \left[512, \frac{x}{16},\frac{y}{16}\right] [512,16x,16y]

3 ◯ \textcircled{\color{purple}\scriptsize 3} 3 RPN网络

RPN网络的输入是特征图,先经过通道数为512的3x3卷积,输出仍为 [ 512 , x 16 , y 16 ] \left[512, \frac{x}{16},\frac{y}{16}\right] [512,16x,16y]
右边这个分支为通道数为36(36是因为每个点有9个anchor,每个anchor有4个坐标)的1x1卷积,输出为 [ 36 , x 16 , y 16 ] \left[36, \frac{x}{16},\frac{y}{16}\right] [36,16x,16y],然后对其进行reshap为 [ a n c h o r 的总数 , 4 ] \left[ anchor的总数,4\right] [anchor的总数,4]大小,记为rpn_loc。
左边这个分支为通道数为18(18是因为每个点有9个anchor,每个anchor要么是背景要么是前景,两种可能)的1x1卷积,输出为 [ 18 , x 16 , y 16 ] \left[18, \frac{x}{16},\frac{y}{16}\right] [18,16x,16y]。然后对其经过softmax处理,最终的输出大小为 [ a n c h o r 的总数 , 2 ] \left[anchor的总数,2\right] [anchor的总数,2],记为rpn_score。

在这里插入图片描述
上述这点清楚以后,我们接下来重点关注RPN网络是如何计算损失的,称之为 L o s s R P N Loss^{RPN} LossRPN。我们都知道,计算loss需要网络输出值和标签值,现在网络输出值已经有了,那么标签值从何而来呢?
从上图可以看到有一个AnchorTargeCreator模块,这个模块的输入是我们产生anchor和gt_bbox,计算出anchor与gt_bbox的真实偏差gt_rpn_loc和该anchor到底负责的是背景还是前景gt_rpn_label。我们就分别将gt_rpn_loc和gt_rpn_label作为标签值与rpn_score、rpn_score计算损失,两个损失之和即为 L o s s R P N Loss^{RPN} LossRPN。损失的具体计算公式这里我们不谈。

在bbuf大佬的解读里,“AnchorTargetCreator 就是将 20000 多个候选的 Anchor 选出 256 个 Anchor 进行分类和回归。”代码里也是采样出了256个样本,但是最后返回的真实标签值是所有的anchor大小,而不是256大小。

ProposalCreator模块的含义如下:
在这里插入图片描述
综上,rpn网络除了自身反向传播训练之外,还通过ProposalCreator模块输出2000个anchor。

4 ◯ \textcircled{\color{blue}\scriptsize 4} 4 ProposalTargetCreator模块

ProposalCreator模块输出2000个ROIS并不全部都使用,经过ProposalTargetCreator模块的筛选(通过与gt_bbox的IOU进行筛选)产生正负一共128个rois。同时输出这128个rois的gt_label和gt_loc。

5 ◯ \textcircled{\color{blue}\scriptsize 5} 5 ROI pooling

这里的ROI pooling和fast rcnn中的是一样的,它的输入是特征图128个rois。ROI Pooling将这些不同尺寸的区域全部pooling到同一个尺度(7x7)上。ROP pooling的输出输入给classifier。

6 ◯ \textcircled{\color{blue}\scriptsize 6} 6 classifier

这里的classifier如下图紫色框出的所示。
在这里插入图片描述
这块的全连接网络可以借用VGG16的全连接网络,代码中也是这么做的。
21代表总共有21类,每个anchor属于每个类的概率,输出为 [ 128 , 21 ] [128,21] [128,21];84 = 21 *4,对每个类别都会有一个坐标信息,输出为 [ 128 , 84 ] [128,84] [128,84],然后分别和gt_label、gt_loc计算损失后相加即为classifier的损失。
suppress为推理时的非极大值抑制,训练时用不到。

反向传播

综上所述,我们将rpn网络的损失和classifier的损失相加,然后进行反向传播即可更新参数。
最后放上BBuf大佬总结的faster rcnn的网络流程图。
在这里插入图片描述
本人才识浅薄,若博文中有不正确的地方,欢迎大家进行批评指正,谢谢。
参考连接:giantpandacv
simple-faster-rcnn-pytorch


http://www.ppmy.cn/news/1051600.html

相关文章

LangChain + Streamlit + Llama:将对话式AI引入本地机器

推荐:使用 NSDT场景编辑器 助你快速搭建可二次编辑的3D应用场景 什么是LLMS? 大型语言模型 (LLM) 是指能够生成与人类语言非常相似的文本并以自然方式理解提示的机器学习模型。这些模型使用包括书籍、文章、网站和其他来源在内的…

读取/加载 properties/yml 配置文件

大家好 , 我是苏麟 , 今天带来一个简单好用的东西 . 读取/加载 properties/yml配置文件 基于PropertiesConfiguration读取配置文件 引入依赖 <!--加载yml资源--><dependency><groupId>org.springframework.boot</groupId><artifactId>spring-b…

代码随想录算法训练营day39 | LeetCode 62. 不同路径 63. 不同路径 II

2. 不同路径&#xff08;题目链接&#xff1a;力扣&#xff08;LeetCode&#xff09;官网 - 全球极客挚爱的技术成长平台&#xff09; 思路&#xff1a;整体思路是一样的&#xff0c;找到状态转移公式和初始化处理。不过这题开始dp数组变成了二维&#xff0c;需要着重考虑初始…

storm集群搭建

升级步骤 1.升级包上传 1.1上传apache-storm-2.2.0.tar.gz包 创建对应升级目录 将升级包apache-storm-2.2.0.tar.gz上传到新创建的目录下 1.2 执行解压命令tar -zxvf apache-storm-2.2.0.tar.gz1.3 将解压后的文件额外复制到其他服务器上 cp -r apache-storm-2.2.0 apache-…

springboot后端返回图片,vue前端接收并显示的解决方案

后端图片数据返回 后端通过二进制流的形式&#xff0c;写入response中 controller层 /*** 获取签到二维码*/GetMapping("/sign-up-pict")public void signUpPict(Long id, Long semId, HttpServletResponse response) throws NoSuchAlgorithmException {signUpServ…

Linux---用户权限管理

权力下放 sudo工具&#xff0c;可以将root的权限下放到普通用户&#xff0c;它允许系统管理员分配给普通用户一些合理的“权力”&#xff0c;让他们执行一些只有超级用户或其他特许用户才能完成的任务&#xff08;主要体现为命令&#xff09;&#xff0c;比如&#xff1a;运行…

innovus加decap如何避免drc问题

我正在「拾陆楼」和朋友们讨论有趣的话题&#xff0c;你⼀起来吧&#xff1f; 拾陆楼知识星球入口 加decap常见的问题就是跟regular wire产生drc&#xff0c;调试如下命令能够避免这类drc。 setFillerMode -add_fillers_with_drc false -fitGap false -ecoMode true

Echarts面积图2.0(范围绘制)

代码&#xff1a; // 以下代码可以直接粘贴在echarts官网的示例上 // 范围值 let normalValue {type: 内部绘制,minValue: 200,maxValue: 750 } // 原本的绘图数据 let seriesData [820, 932, 901, 934, 1290, 1330, 1320] let minData Array.from({length: seriesData.len…