文章目录
- 模型的测试流程
- 1. AnchorHeadTemplate.generate_predicted_boxes部分
- 2. Detector3DTemplate.post_processing部分
- 3. KittiDataset.generate_prediction_dicts部分
- 4. KittiDataset.evaluation部分
模型的测试流程
对于模型来说,训练过程是为了计算构建损失训练模型的参数,验证过程是为了测试模型当前参数的效果。所以,对于模型结构来说需要分别为测试过程和训练过程进行分别规划。在点云的3d检测中,这里主要体现在dense_head预测层中。对于模型来说,其与训练流程的区别结构图如下:
- 对于dense_head处理的区别:
# 功能:构建PointPillar的dense head模块部分
class AnchorHeadSingle(AnchorHeadTemplate):def __init__(self, model_cfg, input_channels, num_class, class_names, grid_size, point_cloud_range,predict_boxes_when_training=True, **kwargs):super().__init__( # 基类没有传参input_channels和voxel_sizemodel_cfg=model_cfg, num_class=num_class,class_names=class_names, grid_size=grid_size,point_cloud_range=point_cloud_range,predict_boxes_when_training=predict_boxes_when_training)...... def forward(self, data_dict):......# 训练过程if self.training:targets_dict = self.assign_targets( # 获取gt信息gt_boxes=data_dict['gt_boxes'])self.forward_ret_dict.update(targets_dict) # 此时记录gt信息以及预测信息,来进行后续的loss计算# 测试过程if not self.training or self.predict_boxes_when_training:batch_cls_preds, batch_box_preds = self.generate_predicted_boxes(batch_size=data_dict['batch_size'],cls_preds=cls_preds, box_preds=box_preds, dir_cls_preds=dir_cls_preds)data_dict['batch_cls_preds'] = batch_cls_predsdata_dict['batch_box_preds'] = batch_box_predsdata_dict['cls_preds_normalized'] = Falsereturn data_dict # 返回更新后的data_dict
- 对于个模块处理后的算法流程区别:
# 功能:基于Detector3DTemplate构建PointPillar算法结构
class PointPillar(Detector3DTemplate):def __init__(self, model_cfg, num_class, dataset):"""Args:model_cfg: yaml配置文件的MODEL部分num_class: 类别数目(kitti数据集一般用3个类别:'Car', 'Pedestrian', 'Cyclist')dataset: 训练数据集"""super().__init__(model_cfg=model_cfg, num_class=num_class, dataset=dataset) # 初始化基类# 网络的各处理模块已经存储在self中(vfe / map_to_bev / backbone_2d ...)self.module_list = self.build_networks() # 真正构建模型的处理函数,Detector3DTemplate的子函数def forward(self, batch_dict):# 各模块分别进行特征处理,更新batch_dict,然后将预测信息与gt信息保存在forward_ret_dict字典中来进行后续的损失计算for cur_module in self.module_list:batch_dict = cur_module(batch_dict)# 训练过程进行损失计算if self.training:loss, tb_dict, disp_dict = self.get_training_loss() # 损失计算ret_dict = {'loss': loss}return ret_dict, tb_dict, disp_dict# 测试过程进行后处理返回预测结果else:pred_dicts, recall_dicts = self.post_processing(batch_dict)return pred_dicts, recall_dicts......
下面会分别对这些核心函数模块进行记录。
1. AnchorHeadTemplate.generate_predicted_boxes部分
generate_predicted_boxes函数一开始的数据传入为:
最后将特征存储在字典中:
2. Detector3DTemplate.post_processing部分
在AnchorHeadTemplate.generate_predicted_boxes函数处理完之后更新的batch_dict字典就是这边后处理函数的输入。可以说,这个batch_dict一直贯穿着整个模型的前向传播过程,包括测试阶段的后处理部分。在训练过程中就用不到这个batch_dict。
3. KittiDataset.generate_prediction_dicts部分
这部分进行kitti数据的预测处理模块
4. KittiDataset.evaluation部分
这部分进行kitti数据的具体验证模块