人工智能学习07--pytorch18--目标检测:Faster RCNN源码解析(pytorch)

news/2024/11/30 2:49:21/

参考博客:
https://blog.csdn.net/weixin_46676835/article/details/130175898

VOC2012

在这里插入图片描述

1、代码的使用

  1. 查看pytorch中的faster-rcnn源码:
    在pytorch中导入:
import torchvision.models.detection.faster_rcnn

在这里插入图片描述
即可找到faster rcnn所实现的源码,但这只是代码的一部分,和训练相关的代码并不在此。
官方提示在pytorch的github上:
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
找到在训练过程中使用的一些文件。
在这里插入图片描述
2. 视频里的代码注意事项:
在这里插入图片描述
https://github.com/pytorch/vision/tree/master/torchvision/models/detection
在这里插入图片描述

pip install pycocotools
pip install pycocotools-windows

在这里插入图片描述
在这里插入图片描述

  • backbone:特征提取网络,可以根据自己的要求选择。 在这里按照官方的样例使用了2个backbone:MobileNetv2、ResNet50+FPN。
  • network_files: Faster R-CNN网络(包括Fast R-CNN以及RPN等模块) 。 构建Faster R-CNN网络的一些模块。主要针对这里的文件进行讲解。
  • train_utils: 训练验证相关模块(包括cocotools)。 涉及训练网络的模块,pytorch官方给的。
  • my_dataset.py: 自定义dataset用于读取VOC数据集。 用于实现一个自定义的dataset。了解自定义数据集的原理后就可以去按照自己的需求创建自己的数据集,再编写相应脚本读取即可。
  • train_mobilenet.py: 以MobileNetV2做为backbone进行训练。 这里讲源码的时候主要以这套来讲,因为是单层的,即预测特征层只有一种。和Faster R-CNN基本上保持一致。但是准确率要低很多。
  • train_resnet50_fpn.py: 以resnet50+FPN做为backbone进行训练。 训练效果很好,实际中尽可能用。官方提供了resnet50+fpn的完整模型权重,利用这个进行迁移学习即可很快地得到一个自己的模型。
  • train_multi_GPU.py: 针对使用多GPU的用户使用。 需要在命令行窗口输入指令,不像前面train_mobilenet.py、train_resnet50_fpn.py这种直接在ide运行脚本即可。
  • predict.py: 简易的预测脚本,使用训练好的权重进行预测测试。
  • validation.py: 利用训练好的权重验证/测试数据的COCO指标,并生成record_mAP.txt文件。
  • pascal_voc_classes.json: pascal_voc标签文件。 pascal_voc类别信息:80个类别,对应整数标签(不从0开始,一般从1开始:目标检测中,0一般留给背景)
    在这里插入图片描述
    这里学习的代码自动去backbone文件夹下寻找所需权重。
    MobileNetV2 weights(下载后重命名为mobilenet_v2.pth,然后放到bakcbone文件夹下): https://download.pytorch.org/models/mobilenet_v2-b0353104.pth 只有backbone的权重,没有后面rpn、fastrcnn 的权重。
    在这里插入图片描述
    完整的模型权重:
    Resnet50 weights(下载后重命名为resnet50.pth,然后放到bakcbone文件夹下): https://download.pytorch.org/models/resnet50-0676ba61.pth
    ResNet50+FPN weights: https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth
    注意,下载的预训练权重记得要重命名,比如在train_resnet50_fpn.py中读取的是fasterrcnn_resnet50_fpn_coco.pth文件, 不是fasterrcnn_resnet50_fpn_coco-258fb6c6.pth,然后放到当前项目根目录下即可。
    在这里插入图片描述
    **Pascal VOC2012 train/val数据集下载地址(没有测试集test):**http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar
    (弹幕:其实按道理来说train训练,val进行调优选择最好的,然后再用test进行测试)

使用ResNet50+FPN以及迁移学习在VOC2012数据集上得到的权重: 链接:https://pan.baidu.com/s/1ifilndFRtAV5RDZINSHj5w 提取码:dsz8

在这里插入图片描述

train_mobilenet.py 训练脚本

如何调用模型进行训练:
在这里插入图片描述

train_res50_fpn.py 训练脚本

同train_mobilenet.py差不多,但是在main前加了一些参数,显得正式一点。
在这里插入图片描述
用RestNet50作为骨干网络时下载的预训练权重包含了FPN和backbone结构,而采用mobilenet作为骨干网络时下载的预训练权重只是backbone的权重。

2、自定义DataSet

split_data.py
先略,等我听懂之后再写

参考pytorch样例

3、FasterRCNN框架

在这里插入图片描述
在这里插入图片描述

faster_rcnn_framework.py

  • class FasterRCNNBase:
    在这里插入图片描述
    roi_heads包括:
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述
    original_image_sizes.append((val[0], val[1])) 记录最原始图像的size,最后得到输出后再映射回原图像,这样得到的目标的边界框才是正确的数值。
    在这里插入图片描述
    在这里插入图片描述
  • class FasterRCNN
    主要是在初始函数中定义一系列参数,用到在FasterRCNNBase中提到过的backbone、rpn、roi_heads模块。
    在这里插入图片描述

在这里插入图片描述
在RPN中,通过预测信息和anchor生成器生成一系列anchor,则可得到所有预测的proposal(可能有成千上万个),在输出前要对proposal进行过滤,过滤筛选之后才有NMS处理。
在这里插入图片描述
这两个一个在nms处理前,一个在nms处理后
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
如果传入的不是它,就会报错
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

4、GeneralizedRCNNTransform

在这里插入图片描述
对传入的图像进行标准化处理,同时对传入的图像以及它的bounding box信息进行resize处理,并且打包成一个一个的batch输入网络进行正向传播

transform.py

  • class GeneralizedRCNNTransform
    在这里插入图片描述
  • def normalize
    在这里插入图片描述
  • def resize
    将图像与它所对应的boundingbox进行缩放处理。将图像的大小放在之前所设定的最大最小值范围内。
    参考:https://blog.csdn.net/weixin_46676835/article/details/130175898
    在这里插入图片描述
  • def forward GeneralizedRCNNTransform类的正向传播过程
    在这里插入图片描述
    在这里插入图片描述
  • def postprocess
    在这里插入图片描述
    将预测的bounding box信息映射回原始的图像尺寸当中。
    在这里插入图片描述

剩下的有时间再补上↓↓↓↓↓

5、RPN(上)

在这里插入图片描述
AnchorsGenerator
RPNHead

6、RPN(中)

在这里插入图片描述
正负样本采样
RPN损失计算

7、RPN(下)

8、ROIAlign、TwoMLPHead、FastRCNNPredictor

9、FastRCNN正负样本划分及采样

10、FastRCNN损失计算

11、预测结果后处理

12、预测结果映射回原尺度

13、换backbone(不带FPN)

14、换backbone(带FPN)

15、训练+预测

我是使用mobilenet那个训练脚本训练的,中间遇到了几个问题,记录一下。
训练脚本:

  1. torchvision版本与cuda版本不匹配的问题
    参考:https://blog.csdn.net/qq_41590635/article/details/112384718?spm=1001.2014.3001.5501 的具体方法
    在这里插入图片描述
    上面要用到的网站在https://zhuanlan.zhihu.com/p/401931724?utm_id=0这篇文章中有提及:https://pytorch.org/get-started/previous-versions/
    在这里插入图片描述
    服务器上用的是这个,要注意torch、torchvision、cuda版本的匹配问题(跟上面这一行一样就行了)
pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113
  1. 在提示报错的……\functional.py文件中,修改478行左右的代码。 加上indexing='ij'
    在这里插入图片描述
    参考文章:https://blog.csdn.net/qq_63378911/article/details/127759513

换了一个服务器

  1. No module named ‘torch._six‘
    在这里插入图片描述
    参考:https://blog.csdn.net/qq_24502827/article/details/130195645?spm=1001.2101.3001.6650.5&utm_medium=distribute.pc_relevant.none-task-blog-2%7Edefault%7EBlogCommendFromBaidu%7ERate-5-130195645-blog-130362406.235%5Ev36%5Epc_relevant_anti_vip&depth_1-utm_source=distribute.pc_relevant.none-task-blog-2%7Edefault%7EBlogCommendFromBaidu%7ERate-5-130195645-blog-130362406.235%5Ev36%5Epc_relevant_anti_vip&utm_relevant_index=6

测试脚本 predict.py:

  1. 忘记了,但是这里加False。好像是忽视迁移学习的数据不匹配问题。
    在这里插入图片描述
  2. 显示数据格式不匹配
    在这里插入图片描述
    在这里插入图片描述
    但是我检查过了,这里确实是一样的。
    说明问题出在别的地方。
    参考:https://blog.csdn.net/qq_40630902/article/details/119762723?spm=1001.2014.3001.5501
    在这里插入图片描述
    最后发现问题是出在了这里,改成图上这样就好了。
    在这里插入图片描述
    之前是直接用的路径weights_path。

http://www.ppmy.cn/news/88943.html

相关文章

【leetcode】!longest substring without repeating chars

参考资料:《剑指offer》,《程序员代码面试指南》 思路: 对每一个位置str[i]来说,找它的以str[i]为end、最长、无重复字符的子串 的过程 相当于 尽可能以str[i]为end, 向左扩, 直至扩到 以str[i-1]为end、最…

JWT(Json Web Token)的原理、渗透与防御

(关于JWT kid安全部分后期整理完毕再进行更新~2023.05.16) JWT的原理、渗透与防御 目录 JWT的原理、渗透与防御含义原理JWT的起源传统session认证问题token与session区别JWT的结构与内容 JWT的攻击和渗透敏感信息泄露空密钥破解密钥爆破CVE-2019-7644 J…

Spring Boot 3.x 系列【38】服务监控 | 自定义端点

有道无术,术尚可求,有术无道,止于术。 本系列Spring Boot版本3.0.5 源码地址:https://gitee.com/pearl-organization/study-spring-boot3 文章目录 1. 概述2. 自定义 Web 端点3. 控制器端点4. Servlet端点5. 扩展端点1. 概述 Spring Boot Actuator默认已提供了很多端点,…

【VR】手柄定位技术

1. 关于Quest Pro头显、控制器的规格分析(终篇)及Quest 3分辨率 (2022年07月29日)被认为是“Quest Pro”的高端一体机Project Cambria将于今年秋季正式发布。对于一直关注和分享所述设备情报的YouTuber布拉德利林奇(B…

亚马逊云科技构建Serverless数据分析战略

亚马逊云科技Amazon EMR是行业领先的大数据分析服务,适用于使用开源框架进行PB级数据处理、交互分析和机器学习,它可以更快地运行大数据应用并且成本不到本地解决方案成本的一半。通过性能优化且兼容开源API的Spark、Hive和Presto版本,洞察时…

【1】安装与配置tensorflow

常见深度学习框架市场占有率 1.创建虚拟环境 打开菜单栏里的 点击creat创建 2.激活虚拟环境 打开命令提示符,输入activate tensorflow 可以看到进入tensorflow环境: 3.更换源 为提高下载速度,执行以下命令: pip config set g…

理解并掌握 Linux 系统下的文件操作命令:mv 与 cp

在 Linux 系统中,文件操作是开发者和管理员必须要掌握的基本技能之一。文件操作包括对文件的创建、读取、修改、删除等。其中,mv 和 cp 命令是常用的文件操作命令,但很多人在使用时常常混淆。本篇文章旨在阐述 mv 和 cp 命令的使用区别和特点…

KD7742电气安规综合测试仪

一、产品简介 KD7742电气安规综合测试仪具有交/直流耐压、绝缘电阻等项目的测试分析功能,能显示电压、电流和电阻的波形图以及趋势图,以便更直观的监测分析绝缘性能和绝缘崩溃时的各项指标,适用于高要求的测试分析场合。 产品具有测试参数范围…