一般大家的pytorch训练代码都比较简洁,mmdet3d为了支持扩展性,把代码进行了很多的抽象和封装,大大降低了可读性。现在简单理一下其training的代码执行逻辑。
实际使用的时候肯定是train几个epoch之后eval一次的,这里只考虑training
训练前配置
配置一下data的config
dataset_type = 'CustomWaymoDataset'
data_root = '/localdata_ssd/waymo_ssd_train_only/kitti_format/'
data = dict(samples_per_gpu=1,workers_per_gpu=4,train=dict(type='RepeatDataset',times=1,dataset=dict(type=dataset_type,data_root=data_root,num_views=num_views,ann_file=data_root + 'waymo_infos_train.pkl',split='training',pipeline=train_pipeline,modality=input_modality,classes=class_names,test_mode=False,# we use box_type_3d='LiDAR' in kitti and nuscenes dataset# and box_type_3d='Depth' in sunrgbd and scannet dataset.box_type_3d='LiDAR',# load one frame every five framesload_interval=5)),
训练
因为大家共享GPU,难免一台机器有其他人在使用部分gpu,因此需要export CUDA_VISIBLE_DEVICES=1,2,3
为dist_train.sh指定配置文件和gpu数量,就能开始训练了
bash tools/dist_train.sh projects/configs/detr3d/detr3d_res101_gridmask_waymo.py 8
tools/dist_train.sh
python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \$(dirname "$0")/train.py $CONFIG --launcher pytorch ${@:3}
分布式运行tools/train.py
tools/train.py
def parse_args()就是传递命令行参数,注意到这里很多参数和config.py里设置的是重复的,习惯直接把参数设置在config.py里。
custom plugin import
如果我们自己写了自定义的模型、dataset等代码,要以plugin的形式嵌入mmdet3d,那么在这里要先import进来。我的理解,import module的时候,python会把整个module的代码跑一遍:首先找到__init__.py,我们提前在里面写好了要import哪些submodule,然后python继续import submodule,比如说,detr3d的detector:
from mmdet.models import DETECTORS
@DETECTORS.register_module()
class Detr3D(MVXTwoStageDetector):"""Detr3D."""...
这个时候,因为有@DETECTORS.register_module()的存在,detr3d就会被mmlab里的registry注册到detectors这个module底下,这样对于mmdet的所有代码来说,detr3d这个类都是可访问的了。之后用mmdet3d.models.build_model()方法,他也能找到这个类,实例化一个detr3d出来。不过需要注意的是,如果把customdataset注册到mmdet里,mmdet3d的dataset builder会找不到我写的dataset类,必须放到mmdet3d里才行。我还没找出原因。还需要进一步理解python的机制。
initialize
之后就会根据config和args做初始化的工作,比如初始化logger,dataset,model等。最后调用mmdet3d.apis.train_model(),里面进一步调用下面的mmdet.apis.train_detector
mmdetection\mmdet\apis\train.py:train_detector()
进一步初始化,初始化dataloader,optimizer,runner等。最后调用runner.run(data_loaders, cfg.workflow)
,进入真正的训练。
runner用于管理整个training procedure,具体原理见mmdet官方教程(我看的知乎)。
mmcv\mmcv\runner\epoch_based_runner.py
runner挂了很多hook,用于定义训练不同阶段的行为,比如训练前后要保存什么信息,每个epoch或者iter前后要做些什么,比如learning rate调整和根据gradient来optimize weights。通常来说每train几个epoch就要eval一次,我们配置参数之后,runner也会帮你做。
runner.train()
定义了一个epoch会做哪些事,epoch前后会call对应的hook,iter前后也会。
self.model.train()self.mode = 'train'self.data_loader = data_loaderself._max_iters = self._max_epochs * len(self.data_loader)self.call_hook('before_train_epoch')time.sleep(2) # Prevent possible deadlock during epoch transitionfor i, data_batch in enumerate(self.data_loader):self._inner_iter = iself.call_hook('before_train_iter')self.run_iter(data_batch, train_mode=True, **kwargs)self.call_hook('after_train_iter')self._iter += 1self.call_hook('after_train_epoch')self._epoch += 1
runner.run_iter()
核心就是
outputs = self.model.train_step(data_batch, self.optimizer,**kwargs)
这样,data就最终进入model的train_step了,比如fcos3d,就会进入继承自父类mmdet.models.detectors.BaseDetector的train_step,进而进行forward,loss等操作,注意这里传进去的optimizer,对于fcos3d来说是没用的,可能将来或者其他地方会用到。