mmdet3d training 流程

news/2024/11/2 17:38:45/

一般大家的pytorch训练代码都比较简洁,mmdet3d为了支持扩展性,把代码进行了很多的抽象和封装,大大降低了可读性。现在简单理一下其training的代码执行逻辑。
实际使用的时候肯定是train几个epoch之后eval一次的,这里只考虑training

训练前配置

配置一下data的config

dataset_type = 'CustomWaymoDataset'
data_root = '/localdata_ssd/waymo_ssd_train_only/kitti_format/' 
data = dict(samples_per_gpu=1,workers_per_gpu=4,train=dict(type='RepeatDataset',times=1,dataset=dict(type=dataset_type,data_root=data_root,num_views=num_views,ann_file=data_root + 'waymo_infos_train.pkl',split='training',pipeline=train_pipeline,modality=input_modality,classes=class_names,test_mode=False,# we use box_type_3d='LiDAR' in kitti and nuscenes dataset# and box_type_3d='Depth' in sunrgbd and scannet dataset.box_type_3d='LiDAR',# load one frame every five framesload_interval=5)),

训练

因为大家共享GPU,难免一台机器有其他人在使用部分gpu,因此需要export CUDA_VISIBLE_DEVICES=1,2,3
为dist_train.sh指定配置文件和gpu数量,就能开始训练了
bash tools/dist_train.sh projects/configs/detr3d/detr3d_res101_gridmask_waymo.py 8

tools/dist_train.sh

python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \$(dirname "$0")/train.py $CONFIG --launcher pytorch ${@:3}

分布式运行tools/train.py

tools/train.py

def parse_args()就是传递命令行参数,注意到这里很多参数和config.py里设置的是重复的,习惯直接把参数设置在config.py里。

custom plugin import

如果我们自己写了自定义的模型、dataset等代码,要以plugin的形式嵌入mmdet3d,那么在这里要先import进来。我的理解,import module的时候,python会把整个module的代码跑一遍:首先找到__init__.py,我们提前在里面写好了要import哪些submodule,然后python继续import submodule,比如说,detr3d的detector:

from mmdet.models import DETECTORS
@DETECTORS.register_module()
class Detr3D(MVXTwoStageDetector):"""Detr3D."""...

这个时候,因为有@DETECTORS.register_module()的存在,detr3d就会被mmlab里的registry注册到detectors这个module底下,这样对于mmdet的所有代码来说,detr3d这个类都是可访问的了。之后用mmdet3d.models.build_model()方法,他也能找到这个类,实例化一个detr3d出来。不过需要注意的是,如果把customdataset注册到mmdet里,mmdet3d的dataset builder会找不到我写的dataset类,必须放到mmdet3d里才行。我还没找出原因。还需要进一步理解python的机制。

initialize

之后就会根据config和args做初始化的工作,比如初始化logger,dataset,model等。最后调用mmdet3d.apis.train_model(),里面进一步调用下面的mmdet.apis.train_detector

mmdetection\mmdet\apis\train.py:train_detector()

进一步初始化,初始化dataloader,optimizer,runner等。最后调用runner.run(data_loaders, cfg.workflow),进入真正的训练。
runner用于管理整个training procedure,具体原理见mmdet官方教程(我看的知乎)。

mmcv\mmcv\runner\epoch_based_runner.py

runner挂了很多hook,用于定义训练不同阶段的行为,比如训练前后要保存什么信息,每个epoch或者iter前后要做些什么,比如learning rate调整和根据gradient来optimize weights。通常来说每train几个epoch就要eval一次,我们配置参数之后,runner也会帮你做。

runner.train()

定义了一个epoch会做哪些事,epoch前后会call对应的hook,iter前后也会。

 self.model.train()self.mode = 'train'self.data_loader = data_loaderself._max_iters = self._max_epochs * len(self.data_loader)self.call_hook('before_train_epoch')time.sleep(2)  # Prevent possible deadlock during epoch transitionfor i, data_batch in enumerate(self.data_loader):self._inner_iter = iself.call_hook('before_train_iter')self.run_iter(data_batch, train_mode=True, **kwargs)self.call_hook('after_train_iter')self._iter += 1self.call_hook('after_train_epoch')self._epoch += 1

runner.run_iter()

核心就是

outputs = self.model.train_step(data_batch, self.optimizer,**kwargs)

这样,data就最终进入model的train_step了,比如fcos3d,就会进入继承自父类mmdet.models.detectors.BaseDetector的train_step,进而进行forward,loss等操作,注意这里传进去的optimizer,对于fcos3d来说是没用的,可能将来或者其他地方会用到。


http://www.ppmy.cn/news/641587.html

相关文章

模型的保存与加载与多gpu的模型保存和加载

模型保存与加载 模型的保存与加载方式 模型保存有两种形式,一种是保存模型的state_dict(),只是保存模型的参数。那么加载时需要先创建一个模型的实例model,之后通过torch.load()将保存的模型参数加载进来,得到dict,再…

D3D中数据从显存、内存相互拷贝的时间对比

显存到内存(分辨率)缩放时间(stretchRect)LockRect(调用GetRenderTargetData,空)LockRect(用memcpy复制数据)400*300015~17 ms数据已经到内存,此项无效800*600016~18 ms数据已经到内存,此项无效1920*10800 20~21ms数据已经到内存&…

mmdetection3d 训练

本文为博主原创文章,未经博主允许不得转载。 本文为专栏《python三维点云从基础到深度学习》系列文章,地址为“https://blog.csdn.net/suiyingy/article/details/124017716”。 本节以SECOND算法为例,简要介绍mmdetection3d second算法训练过程,含数据和python源码详细介绍…

3D模型格式的一点总结

通俗来说,你可以把“格式”理解成基于同一规范的技术表征,也可以再简化点把它看成一种分类方式。对于3D模型来说,格式更是种类繁多。不同应用领域的、不同功能属性的,加密的、独有的、通用的,让人眼花缭乱。 目录 我的…

模型取取 很卡、慢

在3dmax里要使用到一些插件,从别的场景里复制模型到当前场景,这种操作比较方便,不用使用3d自带的比较麻烦的合并导入功能,但是有的时候,使用插件复制模型,会比较卡,有的时候会没有响应&#xff…

Smart3D运行过程中遇到的问题(持续更新)

写在前面:本文是基于我自己的理解而进行解释并找到的解决办法,因此文中解释的原因不一定正确或解决方式最简单合适。 1.问题:tile刚运行就报错:failed to create “C:\Users\ADMINISTRATOR\AppData\Local\Temp\Bentley\ContextCap…

blender物体缩放倍数和保存时物体炸开

问题一:物体缩放倍数奇奇怪怪 有时候看着两个blender的模型场景,比如一栋楼。我复制A.beldner里面的门粘贴到B.blender里面来,啥也看不到,通过按【S】进行缩放之后看到了,可是为啥会放大呢? 有时候将blend…

MMdetection3d环境搭建、使用MMdetection3d做3D目标检测训练自己的数据集、测试、可视化,以及常见的错误

MMdetection3d环境搭建、使用MMdetection3d做3D目标检测训练自己的数据集、测试、可视化,以及常见的错误 1 mmdetection3d环境搭建与测试1.1 从docker开始搭建环境1.1.1 开始从docker环境搭建 1.1.2 测试demo1.2 直接在外部安装mmdetection3d环境1.2.1 创建并激活虚…