二分类gt-pred-loss

server/2024/9/23 4:16:57/

1. GT
gt做在一张np.zeros上就行, 1代表正样本,0是背景, 算Loss时进交叉熵会自动编码

2. 模型端设计
模型的激活函数, 分类分支选sigmoid或softmax ; 如果选softmax, 分类分支 就是输出2通道,通道0默认为背景,通道1为前景

3. 推理脚本


def export_descriptor():"""# input 2 images, output keypoints and correspondencesave prediction:pred:'image': np(320,240)'prob' (keypoints): np (N1, 2)'desc': np (N2, 256)'warped_image': np(320,240)'warped_prob' (keypoints): np (N2, 2)'warped_desc': np (N2, 256)'homography': np (3,3)'matches': np [N3, 4]"""# basic settingsdevice = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")logging.info("train on device: %s", device)# paramsnms_radius = 2confidence_threshold = 0.4 # 0.55max_keypoints = 2000## load pretrainedweights_path = '/media/algo/Disk12T/Superpoint/superpoint_warp_quantu/inference_superpoint_2/superPointNet_16000_focal_checkpoint.pth.tar'val_agent = SuperPointNet_unet()# val_agent = SuperPointNet_unet_vgg()val_agent.load_state_dict(torch.load(weights_path, map_location='cuda')['model_state_dict'])# val_agent.load_state_dict(torch.load(weights_path)['model_state_dict'])val_agent.eval()val_agent.cuda()# #---------dataloader------save_imgpath = Path('/media/algo/Disk12T/Superpoint/superpoint_warp_quantu/inference_superpoint_2/output/')base_path = Path('/media/algo/Disk12T/Superpoint/superpoint_warp_quantu/inference_superpoint_2/data')folder_paths = [x for x in base_path.iterdir() if x.is_dir()] for file_dir in folder_paths:dir = str(file_dir).split('/')[-1]save_dir = os.path.join(save_imgpath, dir)if not os.path.exists(save_dir):os.makedirs(save_dir)# img_list = sorted(glob.glob(os.path.join(file_dir, "*.jpg")))img_list = sorted(glob.glob(os.path.join('/media/algo/Disk12T/Superpoint/superpoint_warp_quantu/SuperPointExtractor/data/images/', "*.jpg")))for i, sample in tqdm(enumerate(img_list)):img_name = sampleinput = _preprocess111(img_name,(944,1824))_,_,H,W = input.shapewith torch.no_grad():start = time.time()outs = val_agent.forward(input.to(device))location, descriptors = outs['semi'], outs['desc']  # (b,2,H,W)  此时location还未做softmaxlocation = torch.nn.functional.softmax(location, dim = 1) location = location[:,1:,:,:]                       # 取1通道 代表特征点的通道         # (1,1,H,W)# end  = time.time()# running_time = end-start# print('time cost : %.5f sec' %running_time)# NMSlocation_pooling = torch.nn.functional.max_pool2d(location, kernel_size = nms_radius * 2 + 1, stride = 1,padding = nms_radius)  # (1,1,H,W)location = location.squeeze()location_pooling = location_pooling.squeeze()end1  = time.time()running_time = end1-start# print('time cost : %.5f sec' %running_time)# xs, ys = torch.where(location > confidence_threshold)xs, ys = torch.where((location > confidence_threshold) & (location == location_pooling))  # 这里的xs是h方向, ys是w方向print('---xs--',xs.shape)end2  = time.time()running_time1 = end2-end1print('time cost : %.5f sec' %running_time1)#=========KL=============# coord0 = torch.cat((xs.unsqueeze(1), ys.unsqueeze(1)), dim = 1)# loc0_at_keypoints = outs['semi'][0, :, coord0[:,0], coord0[:,1]].transpose(0,1) # loc0_prob = F.softmax(loc0_at_keypoints, dim = 1)   # loc0_prob_target = torch.log(loc0_prob+1e-8)            # tt = F.kl_div(torch.log(loc0_prob+1e-8),loc0_prob_target, reduction='none')  #========================       # No keypoints detected.if xs.shape[0] == 0:return np.zeros([3, 0]), np.zeros([128, 0])# Sort keypoints.indicates = location[xs, ys].sort(descending=True)[1]  # 将N个点的score值从大到小排列,返回这些点排序后在xs,ys中的索引if indicates.shape[0] > max_keypoints:indicates = indicates[:max_keypoints]xs = xs[indicates]ys = ys[indicates]keypoints = torch.cat((ys.to(location.dtype), xs.to(location.dtype),location[xs.to(torch.long), ys.to(torch.long)])).reshape([3, -1])# keypoints = keypoints.detach().cpu().numpy()print('--keypoints--',keypoints.shape)# -----------------------------mask_chegai_path = img_name.replace('images','mask').replace('.jpg','.png')mask_chegai_img  = cv2.imread(mask_chegai_path) # (H,W,3)mask_chegai_img = torch.from_numpy(mask_chegai_img).cuda()pts_filter= filter_obj_pts(mask_chegai_img[:,:,0], torch.cat((ys.to(location.dtype), xs.to(location.dtype))).reshape([2, -1]), value=255) # (2,N) // 实例26#-------------------# 先在原始图像中画出检测出来的点mask = np.zeros((H,W,3))img_show = input.detach().cpu().numpy().squeeze()out1 = (np.dstack((img_show, img_show, img_show))* 255.).astype('uint8')basename = os.path.basename(img_name).replace('.jpg','.png')save_path = save_dir + '/'+  basenamefor y , x in zip(pts_filter[1,:], pts_filter[0,:]):# for y , x in zip(xs, ys):cv2.circle(out1, ( int(x),int(y)), radius=1, color=(255,0,255), thickness=1)# path = '/media/algo/Disk12T/Superpoint/superpoint_warp_quantu/pytorch-superpoint-quantu/5_2_1st_10frame'+ str(i)+ '.jpg'cv2.imwrite(save_path, out1)     

http://www.ppmy.cn/server/17317.html

相关文章

密码学系列1-安全规约

本篇介绍了安全性规约的概念,双线性映射,常见困难性问题(离散对数,CDH,DDH,BDH)。 一、大家初看密码方案的时候,一定迷惑于为什么论文用大篇幅进行安全性证明。为什么需要证明安全性呢? 比如一个加密方案,若定义安全性为敌手得不到完整密文,那么敌手就很有可能有能力…

【Flutter 面试题】 setState 在哪种场景下可能会失效?

【Flutter 面试题】 setState 在哪种场景下可能会失效? 文章目录 写在前面口述回答补充说明示例1:`setState` 在已销毁的Widget中使用示例2:在构建过程中调用`setState`写在前面 🙋 关于我 ,小雨青年 👉 CSDN博客专家,GitChat专栏作者,阿里云社区专家博主,51CTO专家…

聚类与分类的区别

聚类和分类是机器学习中的两个基本概念,两者的主要区别在于用于分类的数据已经预先标记好类别,而用于聚类的数据则没有预先标记的类别。以下是详细介绍: 目的不同。聚类的目的是发现数据中的自然分组,将相似或相关的对象组织在一…

Debezium系列之:Debezium2.6以上稳定版本需要注意的重要变动

Debezium系列之:Debezium2.6版本需要注意的重要变动 一、重要变动二、更新所有Connector的配置一、重要变动 最值得注意的变动包括以下内容: 所有快照模式均可用于所有连接器,但不包括仅针对 MySQL 的“never”模式。这意味着以前可能不支持快照模式(例如when_needed)的连…

JavaScript、Java、C#标记过时方法

JavaScript、Java、C#标记过时方法 在JavaScript、Java和C#中,可以使用特定的注解或标记来表示一个方法是不推荐的,以便在使用该方法时发出警告或提示。虽然没有专门用于标记不推荐方法的内置标记,但是可以结合使用deprecated、[Obsolete]等…

Mac和VScode配置fortran

最近更换了mac电脑,其中需要重新配置各类软件平台和运行环境,最近把matlab、gmt、VScode、Endnote等软件全部进行了安装和配置。但是不得不说,mac系统对于经常编程的人来说还是非常友好的! 由于需要对地震位错的程序进行编译运行…

使用Elasticsearch映射定义索引结构

在Elasticsearch中,**映射(Mapping)**是用于定义索引中文档字段的结构、类型及属性的重要组成部分。它相当于数据库表结构的设计,决定了如何对文档中的数据进行解析、存储和检索。本文将详细介绍映射的概念、支持的常规字段类型、…

FPGA ——Verilog语法示例

FPGA ——Verilog语法示例 多模块定义条件判断 多模块定义 genvar i ;generatefor (i0 ; i<8; ii1)beginxdc xdc_u(.d1 (d1 ) ,.d2 (d2 ) ,.d3 (d3 ));end endgenerate条件判断 generate beginif(DEBUG "ON")beginila ila_u(.clk(clk),.probe0({A1,A2,A3,A4}))…