At present this document only covers the source-code walkthrough of the forward inference that follows outputs = model(inputs, coord_ranges, calibs, K=50, mode='test'), together with an inference test program.
DEVIANT: Depth EquiVarIAnt NeTwork for Monocular 3D Object Detection
GitHub: https://github.com/abhi1kumar/DEVIANT
Core idea:
The algorithm outputs 2D and 3D box parameters for each object.
Two coordinate systems are involved: the image coordinate system (2D) and the camera coordinate system (3D); be clear about which system each network output lives in.
1. The heatmap locates the object and gives its class in the image coordinate system.
2. The network regresses the offset of the object's 2D center in the image, which together with the regressed width and height yields the 2D box on the image.
3. The network outputs a residual over a preset per-class 3D mean size (camera coordinate system); adding the two directly gives the object's 3D dimensions.
4. The network outputs the object's depth in the camera coordinate system, i.e. the z value of the 3D center; using the camera parameters together with this z, the x and y of the 3D center in the camera frame are recovered.
5. The network outputs the object's heading in the camera coordinate system: the range [-π, π] is split into num_heading_bin = 12 bins, and a per-bin residual refines the coarse bin, giving the observation angle alpha; finally ry = calibs.alpha2ry(alpha, x) converts alpha to the real yaw angle (see the sketch below).
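A minimal numpy sketch of that heading decoding, mirroring what get_heading_angle and alpha2ry do in the repo (the 24 heading values are assumed to be 12 bin-classification scores followed by 12 per-bin residuals; decode_heading and alpha_to_ry are illustrative names, not the repo's API):

import numpy as np

NUM_HEADING_BIN = 12                      # each bin covers 2*pi/12 = 30 degrees

def decode_heading(heading_vec):
    # heading_vec: 24 values = 12 bin classification scores + 12 per-bin residuals
    bin_scores = heading_vec[:NUM_HEADING_BIN]
    residuals = heading_vec[NUM_HEADING_BIN:]
    bin_id = int(np.argmax(bin_scores))   # coarse bin
    angle_per_bin = 2 * np.pi / NUM_HEADING_BIN
    alpha = bin_id * angle_per_bin + residuals[bin_id]  # coarse bin + fine residual
    if alpha > np.pi:                     # wrap back into [-pi, pi]
        alpha -= 2 * np.pi
    return alpha

def alpha_to_ry(alpha, u, cu, fu):
    # KITTI convention: global yaw = observation angle + ray angle of pixel column u
    ry = alpha + np.arctan2(u - cu, fu)
    if ry > np.pi:
        ry -= 2 * np.pi
    if ry < -np.pi:
        ry += 2 * np.pi
    return ry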
Inference test script
demo_inference_zs.py
import os
import sys
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
ROOT_DIR = os.path.dirname(BASE_DIR)
sys.path.append(ROOT_DIR)

import yaml
import logging
import argparse
import torch
import numpy as np
import random
import cv2
from PIL import Image

from lib.helpers.dataloader_helper import build_dataloader
from lib.helpers.model_helper import build_model
from lib.helpers.optimizer_helper import build_optimizer
from lib.helpers.scheduler_helper import build_lr_scheduler
from lib.helpers.trainer_helper import Trainer
from lib.helpers.tester_helper import Tester
from lib.datasets.kitti_utils import get_affine_transform
from lib.datasets.kitti_utils import Calibration
from lib.datasets.kitti_utils import compute_box_3d
from lib.datasets.waymo import affine_transform
from lib.helpers.save_helper import load_checkpoint
from lib.helpers.decode_helper import extract_dets_from_outputs
from lib.helpers.decode_helper import decode_detections
from lib.helpers.rpn_util import *
from datetime import datetime

parser = argparse.ArgumentParser(description='implementation of DEVIANT')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', help='evaluate model on validation set')
parser.add_argument('--config', type=str, default = 'experiments/config.yaml')
parser.add_argument('--resume_model', type=str, default=None)
args = parser.parse_args()


def create_logger(log_file):
    log_format = '%(asctime)s %(levelname)5s %(message)s'
    logging.basicConfig(level=logging.INFO, format=log_format, filename=log_file)
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    console.setFormatter(logging.Formatter(log_format))
    logging.getLogger().addHandler(console)
    return logging.getLogger(__name__)


def init_torch(rng_seed, cuda_seed):
    """
    Initializes the seeds for ALL potential randomness, including torch, numpy, and random packages.

    Args:
        rng_seed (int): the shared random seed to use for numpy and random
        cuda_seed (int): the random seed to use for pytorch's torch.cuda.manual_seed_all function
    """
    # seed everything
    os.environ['PYTHONHASHSEED'] = str(rng_seed)
    torch.manual_seed(rng_seed)
    np.random.seed(rng_seed)
    random.seed(rng_seed)
    torch.cuda.manual_seed(cuda_seed)
    torch.cuda.manual_seed_all(cuda_seed)

    # make the code deterministic
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def pretty_print(name, input, val_width=40, key_width=0):
    """
    This function creates a formatted string from a given dictionary input.
    It may not support all data types, but can probably be extended.

    Args:
        name (str): name of the variable root
        input (dict): dictionary to print
        val_width (int): the width of the right hand side values
        key_width (int): the minimum key width, (always auto-defaults to the longest key!)

    Example:
        pretty_str = pretty_print('conf', conf.__dict__)
        pretty_str = pretty_print('conf', {'key1': 'example', 'key2': [1,2,3,4,5], 'key3': np.random.rand(4,4)})
        print(pretty_str)
        or
        logging.info(pretty_str)
    """
    pretty_str = name + ': {\n'
    for key in input.keys():
        key_width = max(key_width, len(str(key)) + 4)

    for key in input.keys():
        val = input[key]

        # round values to 3 decimals..
        if type(val) == np.ndarray:
            val = np.round(val, 3).tolist()

        # difficult formatting
        val_str = str(val)
        if len(val_str) > val_width:
            # val_str = pprint.pformat(val, width=val_width, compact=True)
            val_str = val_str.replace('\n', '\n{tab}')
            tab = ('{0:' + str(4 + key_width) + '}').format('')
            val_str = val_str.replace('{tab}', tab)

        # more difficult formatting
        format_str = '{0:' + str(4) + '}{1:' + str(key_width) + '} {2:' + str(val_width) + '}\n'
        pretty_str += format_str.format('', key + ':', val_str)

    # close root object
    pretty_str += '}'
    return pretty_str
def np2tuple(n):
    return (int(n[0]), int(n[1]))


# def main():
# load cfg
config = "/media/data/zs/project/detection3D/github/DEVIANT/code/experiments/run_test.yaml"
cfg = yaml.load(open(config, 'r'), Loader=yaml.Loader)
exp_parent_dir = os.path.join(cfg['trainer']['log_dir'], os.path.basename(config).split(".")[0])
cfg['trainer']['log_dir'] = exp_parent_dir
logger_dir = os.path.join(exp_parent_dir, "log")
os.makedirs(exp_parent_dir, exist_ok=True)
os.makedirs(logger_dir, exist_ok=True)
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
logger = create_logger(os.path.join(logger_dir, timestamp))
pretty = pretty_print('conf', cfg)
logging.info(pretty)
# init torch
init_torch(rng_seed=cfg['random_seed'] - 3, cuda_seed=cfg['random_seed'])

# class mean sizes (h, w, l):
# Ped [1.7431 0.8494 0.911 ]
# Car [1.8032 2.1036 4.8104]
# Cyc [1.7336 0.823 1.753 ]
# Sign [0.6523 0.6208 0.1254]
cls_mean_size = np.array([[1.7431, 0.8494, 0.9110],
                          [1.8032, 2.1036, 4.8104],
                          [1.7336, 0.8230, 1.7530],
                          [0.6523, 0.6208, 0.1254]])

# build model
model = build_model(cfg, cls_mean_size)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Knew = np.array([[7.070493000000e+02, 0.000000000000e+00, 6.040814000000e+02],
#                  [0.000000000000e+00, 7.070493000000e+02, 1.805066000000e+02],
#                  [0.000000000000e+00, 0.000000000000e+00, 1.000000000000e+00]])
# calibs = np.array([7.070493000000e+02/1920.0*768.0*0.9, 0., 6.040814000000e+02/1920.0*768.0, 0.,
#                    0., 7.070493000000e+02/1080.0*512.0*0.9, 1.805066000000e+02/1080.0*512.0, 0.,
#                    0.0, 0.0, 1.0, 0.], dtype=np.float32)
calibs = np.array([9.5749393532104671e+02/1920.0*768.0*0.9, 0., 1.0143950223349893e+03/1920.0*768.0, 0.,
                   0., 9.5697913744394737e+02/1080.0*512.0*0.9, 5.4154074349276050e+02/1080.0*512.0, 0.,
                   0.0, 0.0, 1.0, 0.], dtype=np.float32)
calibsP2 = calibs.reshape(3, 4)
downsample = 4
# interface called from C++
def detect_online(frame_chan, frame):
    print("image channel :", frame_chan)
    resolution = np.array([768, 512])
    img = frame
    index = np.array([frame_chan])
    # img = img[0:1080, int(150-1):int(1920-150-1)]
    # img = cv2.resize(img, (768, 512))
    img_cv_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = Image.fromarray(img_cv_rgb)
    img_size = np.array(img.size)
    center = img_size / 2
    crop_size = img_size
    # coord_ranges = np.array([center - crop_size / 2, center + crop_size / 2]).astype(np.float32)
    coord_ranges = np.array([np.array([0, 0]), resolution]).astype(np.float32)
    features_size = resolution // downsample
    info = {'img_id': index,
            'img_size': resolution,
            'bbox_downsample_ratio': resolution / features_size}

    # add affine transformation for 2d images.
    trans, trans_inv = get_affine_transform(center, crop_size, 0, resolution, inv=1)
    img = img.transform(tuple(resolution.tolist()),
                        method=Image.AFFINE,
                        data=tuple(trans_inv.reshape(-1).tolist()),
                        resample=Image.BILINEAR)

    # image encoding
    img = np.array(img).astype(np.float32) / 255.0
    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
    img = (img - mean) / std
    img = img.transpose(2, 0, 1)  # C * H * W
    inputs = img

    load_checkpoint(model=model,
                    optimizer=None,
                    filename=cfg['tester']['resume_model'],
                    logger=logger,
                    map_location=device)
    model.to(device)
    torch.set_grad_enabled(False)
    model.eval()

    inputs = torch.from_numpy(inputs)
    inputs = inputs.unsqueeze(0)
    calibs = torch.from_numpy(calibsP2)
    calibs = calibs.unsqueeze(0)
    coord_ranges = torch.from_numpy(coord_ranges)
    coord_ranges = coord_ranges.unsqueeze(0)
    inputs = inputs.to(device)
    calibs = calibs.to(device)
    coord_ranges = coord_ranges.to(device)

    # the outputs of centernet
    outputs = model(inputs, coord_ranges, calibs, K=50, mode='test')
    dets = extract_dets_from_outputs(outputs=outputs, K=50)
    dets = dets.detach().cpu().numpy()

    # get corresponding calibs & transform tensor to numpy
    # info = {key: val.detach().cpu().numpy() for key, val in info.items()}
    calibs = Calibration(calibsP2)
    dets = decode_detections(dets=dets,
                             info=info,
                             calibs=calibs,
                             cls_mean_size=cls_mean_size,
                             threshold=cfg['tester']['threshold'])

    # reload the original image (uses the global img_file) and draw the boxes on it
    img = cv2.imread(img_file, cv2.IMREAD_COLOR)
    # img = cv2.resize(img, (768, 512))
    print(dets)
    for i in range(len(dets)):
        # cv2.rectangle(img, (int(dets[i][2]), int(dets[i][3])), (int(dets[i][4]), int(dets[i][5])), (0, 0, 255), 1)
        corners_2d, corners_3d = compute_box_3d(dets[i][6], dets[i][7], dets[i][8], dets[i][12],
                                                (dets[i][9], dets[i][10], dets[i][11]), calibs)
        # map the projected corners from the network resolution back to the original image
        for k in range(corners_2d.shape[0]):
            corners_2d[k] = affine_transform(corners_2d[k], trans_inv)
        # draw the 12 edges of the 3D box
        cv2.line(img, np2tuple(corners_2d[0]), np2tuple(corners_2d[1]), (255, 255, 0), 1, cv2.LINE_4)
        cv2.line(img, np2tuple(corners_2d[1]), np2tuple(corners_2d[2]), (255, 255, 0), 1, cv2.LINE_4)
        cv2.line(img, np2tuple(corners_2d[2]), np2tuple(corners_2d[3]), (255, 255, 0), 1, cv2.LINE_4)
        cv2.line(img, np2tuple(corners_2d[3]), np2tuple(corners_2d[0]), (255, 255, 0), 1, cv2.LINE_4)
        cv2.line(img, np2tuple(corners_2d[4]), np2tuple(corners_2d[5]), (255, 255, 0), 1, cv2.LINE_4)
        cv2.line(img, np2tuple(corners_2d[5]), np2tuple(corners_2d[6]), (255, 255, 0), 1, cv2.LINE_4)
        cv2.line(img, np2tuple(corners_2d[6]), np2tuple(corners_2d[7]), (255, 255, 0), 1, cv2.LINE_4)
        cv2.line(img, np2tuple(corners_2d[7]), np2tuple(corners_2d[4]), (255, 255, 0), 1, cv2.LINE_4)
        cv2.line(img, np2tuple(corners_2d[5]), np2tuple(corners_2d[1]), (255, 255, 0), 1, cv2.LINE_4)
        cv2.line(img, np2tuple(corners_2d[4]), np2tuple(corners_2d[0]), (255, 255, 0), 1, cv2.LINE_4)
        cv2.line(img, np2tuple(corners_2d[6]), np2tuple(corners_2d[2]), (255, 255, 0), 1, cv2.LINE_4)
        cv2.line(img, np2tuple(corners_2d[7]), np2tuple(corners_2d[3]), (255, 255, 0), 1, cv2.LINE_4)
    cv2.imwrite("3.jpg", img)


if __name__ == '__main__':
    img_file = "/media/data/zs/project/detection3D/github/DEVIANT/code/tools/img/result2_00001225.jpg"
    # img_file = "./151608731.jpg"
    frame = cv2.imread(img_file, cv2.IMREAD_COLOR)
    frame_chan = 0
    detect_online(frame_chan, frame)
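The hard-coded calibs array in the script rescales the camera's original 1920x1080 intrinsics to the 768x512 network input: each focal length and principal-point coordinate is multiplied by the per-axis resize ratio (the extra 0.9 on the focal lengths appears to be an empirical tweak by the author). A sketch of that rule, where rescale_intrinsics is a hypothetical helper and not repo code:

import numpy as np

def rescale_intrinsics(fx, fy, cx, cy, src_wh, dst_wh, focal_scale=1.0):
    # when an image is resized from src_wh to dst_wh, the intrinsics scale
    # by the same per-axis ratios
    sx = dst_wh[0] / src_wh[0]
    sy = dst_wh[1] / src_wh[1]
    P2 = np.array([[fx * sx * focal_scale, 0., cx * sx, 0.],
                   [0., fy * sy * focal_scale, cy * sy, 0.],
                   [0., 0., 1., 0.]], dtype=np.float32)
    return P2

# reproduces the calibsP2 values used in the script
P2 = rescale_intrinsics(9.5749393532104671e+02, 9.5697913744394737e+02,
                        1.0143950223349893e+03, 5.4154074349276050e+02,
                        (1920.0, 1080.0), (768.0, 512.0), focal_scale=0.9)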
Source code walkthrough (from lib/helpers/decode_helper.py)
def decode_detections(dets, info, calibs, cls_mean_size, threshold):
    '''
    NOTE: THIS IS A NUMPY FUNCTION
    input: dets, numpy array, shape in [batch x max_dets x dim]
    input: img_info, dict, necessary information of input images
    input: calibs, corresponding calibs for the input batch
    output:
    '''
    results = []
    for i in range(dets.shape[0]):      # batch
        preds = []
        for j in range(dets.shape[1]):  # max_dets
            cls_id = int(dets[i, j, 0])
            score = dets[i, j, 1]
            if score < threshold:
                continue

            # 2d bboxs decoding
            x = dets[i, j, 2] * info['bbox_downsample_ratio'][0]
            y = dets[i, j, 3] * info['bbox_downsample_ratio'][1]
            w = dets[i, j, 4] * info['bbox_downsample_ratio'][0]
            h = dets[i, j, 5] * info['bbox_downsample_ratio'][1]
            bbox = [x - w / 2, y - h / 2, x + w / 2, y + h / 2]

            # 3d bboxs decoding
            # depth decoding
            depth = dets[i, j, 6]

            # heading angle decoding
            alpha = get_heading_angle(dets[i, j, 7:31])
            ry = calibs.alpha2ry(alpha, x)  # alpha and ry both describe the object orientation, as defined in the KITTI dataset

            # dimensions decoding
            dimensions = dets[i, j, 31:34]
            dimensions += cls_mean_size[int(cls_id)]
            if True in (dimensions < 0.0):
                continue

            # positions decoding
            x3d = dets[i, j, 34] * info['bbox_downsample_ratio'][0]
            y3d = dets[i, j, 35] * info['bbox_downsample_ratio'][1]
            locations = calibs.img_to_rect(x3d, y3d, depth).reshape(-1)
            # this step back-projects the projected 3D object center (x3d, y3d) on the image plane,
            # using the camera parameters and the depth value, to the 3D object center (x, y) in the
            # camera coordinate system, i.e. the real-world 3D object center

            preds.append([cls_id, alpha] + bbox + dimensions.tolist() + locations.tolist() + [ry, score])
        results = preds  # only the last batch is kept; fine here since the demo uses batch size 1
    return results
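The img_to_rect call above is where the predicted depth becomes a full 3D location: it back-projects the projected 3D center (x3d, y3d) through the pinhole model. A minimal numpy sketch of that back-projection, assuming a KITTI-style 3x4 P2 matrix (the function below is illustrative and mirrors what the Calibration class does):

import numpy as np

def img_to_rect(u, v, depth, P2):
    # pixel (u, v) at a given depth -> (x, y, z) in the camera (rect) frame
    fu, fv = P2[0, 0], P2[1, 1]                  # focal lengths
    cu, cv = P2[0, 2], P2[1, 2]                  # principal point
    tx, ty = P2[0, 3] / (-fu), P2[1, 3] / (-fv)  # stereo baseline terms, 0 for a monocular P2
    x = (u - cu) * depth / fu + tx
    y = (v - cv) * depth / fv + ty
    return np.array([x, y, depth])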
# two stage style
def extract_dets_from_outputs(outputs, K=50):
    # get src outputs
    heatmap = outputs['heatmap']      # shape = [1, 3, 128, 192]; the heatmap is the starting point that yields object class and position; 3 is the number of classes
    size_2d = outputs['size_2d']      # shape = [1, 2, 128, 192]; detected object size on the image
    offset_2d = outputs['offset_2d']  # shape = [1, 2, 128, 192]; offset of the detected object center on the image

    batch, channel, height, width = heatmap.size()  # get shape

    heading = outputs['heading'].view(batch, K, -1)         # shape = [1, 50, 24]; object heading in the camera coordinate system
    depth = outputs['depth'].view(batch, K, -1)[:, :, 0:1]  # shape = [1, 50, 1]; object depth in the camera coordinate system, i.e. the z value
    size_3d = outputs['size_3d'].view(batch, K, -1)         # shape = [1, 50, 3]; 3D object size in the camera coordinate system
    offset_3d = outputs['offset_3d'].view(batch, K, -1)     # shape = [1, 50, 2]; offset of the projected 3D center in the image coordinate system

    heatmap = torch.clamp(heatmap.sigmoid_(), min=1e-4, max=1 - 1e-4)

    # perform nms on heatmaps
    heatmap = _nms(heatmap)  # shape = [1, 3, 128, 192]
    scores, inds, cls_ids, xs, ys = _topk(heatmap, K=K)
    # scores, inds, cls_ids, xs, ys all have shape [1, 50]. There are two top-k passes:
    # the first selects 50 candidates per class (3 * 50 = 150 in total),
    # the second selects the final 50 objects out of those 150 across all classes.

    offset_2d = _transpose_and_gather_feat(offset_2d, inds)  # shape = [1, 50, 2]; gather the 2D center offsets on the image
    offset_2d = offset_2d.view(batch, K, 2)                  # shape = [1, 50, 2]
    xs2d = xs.view(batch, K, 1) + offset_2d[:, :, 0:1]       # shape = [1, 50, 1]; x of the 2D object center on the image
    ys2d = ys.view(batch, K, 1) + offset_2d[:, :, 1:2]       # shape = [1, 50, 1]; y of the 2D object center on the image
    xs3d = xs.view(batch, K, 1) + offset_3d[:, :, 0:1]       # shape = [1, 50, 1]; x of the projected 3D center on the image
    ys3d = ys.view(batch, K, 1) + offset_3d[:, :, 1:2]       # shape = [1, 50, 1]; y of the projected 3D center on the image

    cls_ids = cls_ids.view(batch, K, 1).float()  # shape = [1, 50, 1]
    depth_score = (-(0.5 * outputs['depth'].view(batch, K, -1)[:, :, 1:2]).exp()).exp()  # shape = [1, 50, 1]; depth-uncertainty confidence
    scores = scores.view(batch, K, 1) * depth_score  # shape = [1, 50, 1]

    # check shape
    xs2d = xs2d.view(batch, K, 1)
    ys2d = ys2d.view(batch, K, 1)
    xs3d = xs3d.view(batch, K, 1)
    ys3d = ys3d.view(batch, K, 1)

    size_2d = _transpose_and_gather_feat(size_2d, inds)
    size_2d = size_2d.view(batch, K, 2)  # shape = [1, 50, 2]

    detections = torch.cat([cls_ids, scores, xs2d, ys2d, size_2d, depth, heading, size_3d, xs3d, ys3d], dim=2)
    return detections


############### auxiliary function ############

def _nms(heatmap, kernel=3):
    padding = (kernel - 1) // 2
    heatmapmax = nn.functional.max_pool2d(heatmap, (kernel, kernel), stride=1, padding=padding)
    keep = (heatmapmax == heatmap).float()
    return heatmap * keep


def _topk(heatmap, K=50):
    batch, cat, height, width = heatmap.size()

    # batch * cls_ids * 50
    topk_scores, topk_inds = torch.topk(heatmap.view(batch, cat, -1), K)
    topk_inds = topk_inds % (height * width)
    topk_ys = (topk_inds / width).int().float()
    topk_xs = (topk_inds % width).int().float()

    # batch * cls_ids * 50
    topk_score, topk_ind = torch.topk(topk_scores.view(batch, -1), K)
    topk_cls_ids = (topk_ind / K).int()
    topk_inds = _gather_feat(topk_inds.view(batch, -1, 1), topk_ind).view(batch, K)
    topk_ys = _gather_feat(topk_ys.view(batch, -1, 1), topk_ind).view(batch, K)
    topk_xs = _gather_feat(topk_xs.view(batch, -1, 1), topk_ind).view(batch, K)
    return topk_score, topk_inds, topk_cls_ids, topk_xs, topk_ys


def _gather_feat(feat, ind, mask=None):
    '''
    Args:
        feat: tensor shaped in B * (H*W) * C
        ind:  tensor shaped in B * K (default: 50)
        mask: tensor shaped in B * K (default: 50)
    Returns: tensor shaped in B * K or B * sum(mask)
    '''
    dim = feat.size(2)  # get channel dim
    ind = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim)  # B*len(ind) --> B*len(ind)*1 --> B*len(ind)*C
    feat = feat.gather(1, ind)  # B*(HW)*C ---> B*K*C
    if mask is not None:
        mask = mask.unsqueeze(2).expand_as(feat)  # B*50 ---> B*K*1 --> B*K*C
        feat = feat[mask]
        feat = feat.view(-1, dim)
    return feat
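To see the two-stage top-k from _topk in isolation, here is a tiny self-contained run on a random heatmap with toy sizes (the real feature map is 128x192); it uses explicit floor division // where the excerpt above relies on integer tensor division plus .int():

import torch

batch, cat, height, width = 1, 3, 4, 6
K = 5
heatmap = torch.rand(batch, cat, height, width)

# stage 1: top-K peaks independently for each of the 3 classes -> 3*K candidates
topk_scores, topk_inds = torch.topk(heatmap.view(batch, cat, -1), K)  # [1, 3, 5]
topk_ys = topk_inds // width                                          # row of each peak
topk_xs = topk_inds % width                                           # column of each peak

# stage 2: top-K over all 3*K candidates; the flattened index encodes cls*K + k,
# so integer division by K recovers the class id
topk_score, topk_ind = torch.topk(topk_scores.view(batch, -1), K)     # [1, 5]
topk_cls_ids = topk_ind // K
print(topk_score)    # final scores across all classes
print(topk_cls_ids)  # class id of each of the K detections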