1. GT
GT 直接画在一张 np.zeros 数组上即可：1 代表正样本（特征点），0 代表背景；计算 Loss 时交叉熵会直接把这些类别索引当作目标，无需手动做 one-hot 编码。
2. 模型端设计
模型的激活函数：分类分支可选 sigmoid 或 softmax；如果选 softmax，分类分支就输出 2 个通道，通道 0 默认为背景，通道 1 为前景。
3. 推理脚本
_PROJECT_ROOT = '/media/algo/Disk12T/Superpoint/superpoint_warp_quantu'


def _detect_keypoints(semi, nms_radius, confidence_threshold, max_keypoints):
    """Pick NMS-filtered keypoints from raw 2-channel detector logits.

    Args:
        semi: raw (not yet softmaxed) logits shaped (B, 2, H, W); channel 1
            is the keypoint class. Only the first image of the batch is used
            (the caller feeds one image at a time).
        nms_radius: half-width of the square max-pool window used for NMS.
        confidence_threshold: minimum keypoint probability to keep.
        max_keypoints: keep at most this many highest-scoring points.

    Returns:
        (rows, cols): 1-D long tensors of pixel coordinates (height index,
        width index), ordered by descending keypoint probability. Both are
        empty when nothing passes the threshold.
    """
    # Softmax over the class axis, then keep the keypoint channel -> (B, 1, H, W).
    prob = torch.nn.functional.softmax(semi, dim=1)[:, 1:, :, :]
    # A pixel survives NMS when it equals the maximum of its neighbourhood.
    pooled = torch.nn.functional.max_pool2d(
        prob, kernel_size=nms_radius * 2 + 1, stride=1, padding=nms_radius)
    prob, pooled = prob[0, 0], pooled[0, 0]  # (H, W)
    rows, cols = torch.where((prob > confidence_threshold) & (prob == pooled))
    # Indices into rows/cols, highest score first, truncated to max_keypoints.
    order = prob[rows, cols].sort(descending=True)[1][:max_keypoints]
    return rows[order], cols[order]


def _draw_points(image_tensor, pts):
    """Render the network input as a 3-channel uint8 image and mark points.

    Args:
        image_tensor: the preprocessed input tensor; after squeezing it is
            assumed to be a single (H, W) grayscale image scaled to [0, 1]
            (TODO confirm against _preprocess111).
        pts: a (2, N) array/tensor whose row 0 is the x (width) coordinate and
            row 1 the y (height) coordinate, or None to draw nothing.

    Returns:
        The (H, W, 3) uint8 image with a magenta dot at every point.
    """
    gray = image_tensor.detach().cpu().numpy().squeeze()
    canvas = (np.dstack((gray, gray, gray)) * 255.).astype('uint8')
    if pts is not None:
        for x, y in zip(pts[0, :], pts[1, :]):
            cv2.circle(canvas, (int(x), int(y)), radius=1, color=(255, 0, 255), thickness=1)
    return canvas


def export_descriptor(
        weights_path=_PROJECT_ROOT + '/inference_superpoint_2/superPointNet_16000_focal_checkpoint.pth.tar',
        data_root=_PROJECT_ROOT + '/inference_superpoint_2/data',
        output_root=_PROJECT_ROOT + '/inference_superpoint_2/output/',
        image_dir=_PROJECT_ROOT + '/SuperPointExtractor/data/images/',
        input_size=(944, 1824),
        nms_radius=2,
        confidence_threshold=0.4,
        max_keypoints=2000):
    """Detect keypoints on a set of images and save visualisations.

    For every sub-directory of ``data_root`` a matching output directory is
    created under ``output_root``. Each ``*.jpg`` image is run through
    ``SuperPointNet_unet``. Keypoints are chosen by softmax + max-pool NMS +
    threshold + top-k and then filtered with ``filter_obj_pts`` against the
    image's mask, which is read from the path obtained by replacing 'images'
    with 'mask' and '.jpg' with '.png'. The surviving points are drawn on the
    input and written as a ``.png`` into the output directory.

    Args:
        weights_path: checkpoint containing a 'model_state_dict' entry.
        data_root: directory whose sub-directories define the output folders.
        output_root: directory where the per-folder outputs are written.
        image_dir: directory the images are read from for EVERY folder. Pass
            None to read the images from each ``data_root`` sub-folder instead.
        input_size: size passed to ``_preprocess111``.
        nms_radius: NMS half-window in pixels.
        confidence_threshold: minimum keypoint probability.
        max_keypoints: maximum number of keypoints kept per image.

    Raises:
        FileNotFoundError: if an image's mask file cannot be read.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    logging.info("train on device: %s", device)

    # Load the pretrained detector onto the selected device (works without CUDA).
    val_agent = SuperPointNet_unet()
    val_agent.load_state_dict(torch.load(weights_path, map_location=device)['model_state_dict'])
    val_agent.eval()
    val_agent.to(device)

    folder_paths = [p for p in Path(data_root).iterdir() if p.is_dir()]
    for file_dir in folder_paths:
        save_dir = os.path.join(output_root, file_dir.name)
        os.makedirs(save_dir, exist_ok=True)

        src_dir = file_dir if image_dir is None else image_dir
        img_list = sorted(glob.glob(os.path.join(src_dir, "*.jpg")))
        for img_name in tqdm(img_list):
            image_tensor = _preprocess111(img_name, input_size)
            with torch.no_grad():
                outs = val_agent(image_tensor.to(device))
                # outs['semi'] is (B, 2, H, W); softmax has not been applied yet.
                semi = outs['semi']
                xs, ys = _detect_keypoints(semi, nms_radius, confidence_threshold, max_keypoints)

            basename = os.path.basename(img_name).replace('.jpg', '.png')
            save_path = os.path.join(save_dir, basename)

            if xs.numel() == 0:
                # Previously this returned from the whole function, silently
                # skipping all remaining images; save the bare image instead.
                logging.info("no keypoints detected in %s", img_name)
                cv2.imwrite(save_path, _draw_points(image_tensor, None))
                continue

            mask_path = img_name.replace('images', 'mask').replace('.jpg', '.png')
            mask_img = cv2.imread(mask_path)  # (H, W, 3); None if unreadable
            if mask_img is None:
                raise FileNotFoundError(f"cannot read mask image: {mask_path}")
            mask_img = torch.from_numpy(mask_img).to(device)

            # (2, N): row 0 = width coordinate (x), row 1 = height coordinate (y).
            pts = torch.stack((ys, xs)).to(semi.dtype)
            pts_filter = filter_obj_pts(mask_img[:, :, 0], pts, value=255)

            cv2.imwrite(save_path, _draw_points(image_tensor, pts_filter))