PointNet2(一)分类

news/2024/9/19 12:11:59/ 标签: 分类, 数据挖掘, 人工智能

        发现PVN3D中使用到了pointnet2和 densfusion等网络,为了看懂pvn3d,因此得看看pointnet2,然而带cpp,cu文件的程序一时办事编译不成功,因此找到了一个 Pointnet_Pointnet2_pytorch-master,里面有pointnet和pointnet2网络,在这个程序中学习pointnet2.

首先看Pointnet2_utils.py文件

1.pc_normalize函数

        这一个是点归一化,就是找到点集合得中心,然后每个点都减去这个中心,然后计算x*x+y*y+z*z再开根号,得到距离。使用max计算最大距离,然后让每一个重心化之后的点除上这个距离,就得到了归一化得坐标。

def pc_normalize(pc):l = pc.shape[0]   #点的个数centroid = np.mean(pc, axis=0)  #点的中心pc = pc - centroid  #重心化m = np.max(np.sqrt(np.sum(pc**2, axis=1)))  #计算得到点里中心点的最大距离pc = pc / m  #除上最大距离return pc

2.square_distance   

M个点和N个点之间的距离构成了N*M矩阵, 每一个元素(i,j)中的值存储的都是,点集合N中的第i个点和  点集合M中的第j个点之间的距离。


#计算距离矩阵 比如src是 B组N个点  dst是B组M个点, 则最后得到的距离矩阵是 B组N行M列
def square_distance(src, dst):"""Calculate Euclid distance between each two points.src^T * dst = xn * xm + yn * ym + zn * zm;sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn;sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm;dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2= sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dstInput:src: source points, [B, N, C]dst: target points, [B, M, C]Output:dist: per-point square distance, [B, N, M]"""B, N, _ = src.shape  #B组M个点_, M, _ = dst.shape  #B组B个点dist = -2 * torch.matmul(src, dst.permute(0, 2, 1)) #B组 N行M列矩阵,每一个元素都是-2*i*jdist += torch.sum(src ** 2, -1).view(B, N, 1)  #平方项,使用了广播特性dist += torch.sum(dst ** 2, -1).view(B, 1, M)  #平方项,使用了广播特性return dist

3.index_points

根据点索引提取点坐标或者点属性特征矢量


#根据idx 索引得到点位置或者属性
def index_points(points, idx):"""Input:points: input points data, [B, N, C]idx: sample index data, [B, S]Return:new_points:, indexed points data, [B, S, C]"""device = points.device  #点所在的设备B = points.shape[0]  #B组view_shape = list(idx.shape)  #view_shape[1:] = [1] * (len(view_shape) - 1)repeat_shape = list(idx.shape)repeat_shape[0] = 1batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape)new_points = points[batch_indices, idx, :]return new_points

4.最远点采样

    算法原理比较简单,首先随机生成B*1的farthest,作为centroids的第一次选择的点索引,然后

使用xyz[batch_indices, farthest, :].view(B, 1, 3) 把这些索引到的点给提取出来,然后让每一组中的点都减去对应组的索引点,再计算距离,再第0次迭代中,由于计算出来的点点距离都小于1e10,因此distance都被更新了。

        这时候选择出来最大距离点,就是第二个最远点。选择该最远点的索引作为第1次迭代(注意从第0次迭代开始),然后还是与每一个点进行比较,计算距离,如果距离小于distance中的距离,则让该距离小的值填充distance中对应位置,最后,在选择该轮中distance中的最大值所对应的索引作为 farthest。

相当于每一组一个distanceK,每一次都用距离最小值更新distanceK, 等同于选择与 最远点集中所有点 距离最小的点中的 距离最大点,有点绕口,详细来说,分为三步:

1)让剩余点集中的一个点,与最远点集中的所有点计算距离,选择最小距离

2)遍历1),得到剩余点集中每一个点 到 最远点集中的 最小距离

3)在所有最小距离中,选择最大距离所对应的点


def farthest_point_sample(xyz, npoint):"""Input:xyz: pointcloud data, [B, N, 3]npoint: number of samplesReturn:centroids: sampled pointcloud index, [B, npoint]"""device = xyz.device #设备B, N, C = xyz.shape  #batch  n个点  ccentroids = torch.zeros(B, npoint, dtype=torch.long).to(device)distance = torch.ones(B, N).to(device) * 1e10  #B*N矩阵,farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device) # 随机生成B个 0-N中的数值batch_indices = torch.arange(B, dtype=torch.long).to(device) #B个索引值for i in range(npoint):centroids[:, i] = farthest # 第0次是 随机的索引,centroid = xyz[batch_indices, farthest, :].view(B, 1, 3) #batch_indices = :, farthest, centroid得到中心点dist = torch.sum((xyz - centroid) ** 2, -1) # 减去 farthest 所对应的点,得到距离值mask = dist < distance  #生成掩膜distance[mask] = dist[mask] #把dist中比distance中对应位置更小的距离,更新到distance中farthest = torch.max(distance, -1)[1] #最大距离值的索引return centroids

5.query_ball_point

功能:让group_idx 为N的位置填充上 每一组中每一行第一个元素值,相当于是小于nsample个点的时候,就填充第一个点的索引

理解:最远采样点为中心,得到小于中心+radius 区域内的点的索引,如果不够,则使用第一个查询的点的索引进行填充。最终每一个组中,每一个点,都采样到了相同个数(nsample个)的点,

def query_ball_point(radius, nsample, xyz, new_xyz):"""Input:radius: local region radiusnsample: max sample number in local regionxyz: all points, [B, N, 3]new_xyz: query points, [B, S, 3]Return:group_idx: grouped points index, [B, S, nsample]"""device = xyz.deviceB, N, C = xyz.shape_, S, _ = new_xyz.shapegroup_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1]) #group_idx =[B,S,N]sqrdists = square_distance(new_xyz, xyz) #得到new_xyz和xyz之间的距离矩阵group_idx[sqrdists > radius ** 2] = N #距离矩阵中的值大于查询半径的时候,,索引值设置为N,group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample] #按照行进行排序(相当于对一行中的所有列进行排序),只选取前nsample个group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])  #让每一组中每一行第一个元素值,填充group_firstmask = group_idx == N #制作一个掩膜group_idx[mask] = group_first[mask] #让group_idx 为N的位置填充上 每一组中每一行第一个元素值,相当于是小于nsample个点的时候,就填充第一个点的索引return group_idx #返回索引

6.sample_and_group

这个函数的活,除了最远点采样和聚组 这两个函数,剩下的就是一个根据索引计算属性的函数了,然后把点位置和点属性进行特征连接,得到心得特征。

def sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False):"""Input:npoint:radius:nsample:xyz: input points position data, [B, N, 3]points: input points data, [B, N, D]Return:new_xyz: sampled points position data, [B, npoint, nsample, 3]new_points: sampled points data, [B, npoint, nsample, 3+D]"""B, N, C = xyz.shapeS = npointfps_idx = farthest_point_sample(xyz, npoint) # [B, npoint, C]new_xyz = index_points(xyz, fps_idx)  #得到最远点采样的点集合idx = query_ball_point(radius, nsample, xyz, new_xyz)grouped_xyz = index_points(xyz, idx) # [B, npoint, nsample, C] #得到球查询的点的组grouped_xyz_norm = grouped_xyz - new_xyz.view(B, S, 1, C) #使用最远采样点作为中心点,进行重心化if points is not None:grouped_points = index_points(points, idx)  #对属性/特征进行采样new_points = torch.cat([grouped_xyz_norm, grouped_points], dim=-1) # [B, npoint, nsample, C+D] 连接点位置(3) 和 点特征else:new_points = grouped_xyz_norm #只要点位置,没有点属性if returnfps:return new_xyz, new_points, grouped_xyz, fps_idxelse:return new_xyz, new_points #返回了最远点采样中心点,和生成的特征

7.sample_and_group_all

传进来的点和属性全都用,不做采样了,然后如果有属性,就把位置和属性连接起来构成新的属性,如果没有属性,则只使用位置。


def sample_and_group_all(xyz, points):"""Input:xyz: input points position data, [B, N, 3]points: input points data, [B, N, D]Return:new_xyz: sampled points position data, [B, 1, 3]new_points: sampled points data, [B, 1, N, 3+D]"""device = xyz.deviceB, N, C = xyz.shapenew_xyz = torch.zeros(B, 1, C).to(device)grouped_xyz = xyz.view(B, 1, N, C)if points is not None:new_points = torch.cat([grouped_xyz, points.view(B, 1, N, -1)], dim=-1)else:new_points = grouped_xyzreturn new_xyz, new_points

8.PointNetSetAbstractionMsg

# PointNetSetAbstractionMsg(512, [0.1, 0.2, 0.4], [16, 32, 128], in_channel,[[32, 32, 64], [64, 64, 128], [64, 96, 128]])

看这一行参数,表示,最远点采样采集512个点作为中心点,开始聚组,球半径分别为,0.1,0.2,0.3,在每一个半径中选择16个点,32个点和128个点。in_chanenl在第一次时候为3(位置),为6(带法向)。分别生成了:

B*512*16*3    通道为3,最远点采样了512个点,每个点组了16个点,每个点是3维。

B*512*32*3

B*512*128*3

三个张量。

第一个SA:

然后B*512*16*3    通道为3,[32, 32, 64]分别relu(bn(conv()))三件套,三角套第一次生成 B*32* 512*16,max后把组点个数给销毁,变成B*32* 512; 在第一次基础上第二次生成B*32* 512*16,,max后把组点个数给销毁,变成B*32* 512; 在第二次基础上第三次生成B*64* 512*16,,max后把组点个数给销毁,变成B*64* 512. 。

new_xyz, 为512个点
new_points  选中列表最后一个数值,[32, 32, 64], [64, 64, 128], [64, 96, 128]], 因此结果为:B*(64+ 128 + 128)* 512 ,为 B* 320* 512.

也就是第一次SA后得到new_xyz(512个点),new_points (B* 320* 512)

第二个SA:

self.sa2 = PointNetSetAbstractionMsg(128, [0.2, 0.4, 0.8], [32, 64, 128], 320,[[64, 64, 128], [128, 128, 256], [128, 128, 256]])

同理,是在512个点中采集128个点,这一次特征矢量是320维度+3位置=323,最后计算结果就是

B*(128+ 256+ 256)* 512 ,为 B* 640* 512.

第三个SA:

PointNetSetAbstraction(None, None, None, 640 + 3, [256, 512, 1024], True)

显然,三次三件套后就是 B* 1024* 128,128是第二次SA中的最远采样点个数,然后max一下变成了B* 1024矢量


# PointNetSetAbstractionMsg(512, [0.1, 0.2, 0.4], [16, 32, 128], in_channel,[[32, 32, 64], [64, 64, 128], [64, 96, 128]])
class PointNetSetAbstraction(nn.Module):def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all):super(PointNetSetAbstraction, self).__init__()self.npoint = npointself.radius = radiusself.nsample = nsampleself.mlp_convs = nn.ModuleList()self.mlp_bns = nn.ModuleList()last_channel = in_channelfor out_channel in mlp:self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1))self.mlp_bns.append(nn.BatchNorm2d(out_channel))last_channel = out_channelself.group_all = group_alldef forward(self, xyz, points):"""Input:xyz: input points position data, [B, C, N]points: input points data, [B, D, N]Return:new_xyz: sampled points position data, [B, C, S]new_points_concat: sample points feature data, [B, D', S]"""xyz = xyz.permute(0, 2, 1)if points is not None:points = points.permute(0, 2, 1)if self.group_all:new_xyz, new_points = sample_and_group_all(xyz, points)else:new_xyz, new_points = sample_and_group(self.npoint, self.radius, self.nsample, xyz, points)# new_xyz: sampled points position data, [B, npoint, C]# new_points: sampled points data, [B, npoint, nsample, C+D]new_points = new_points.permute(0, 3, 2, 1) # [B, C+D, nsample,npoint]for i, conv in enumerate(self.mlp_convs):bn = self.mlp_bns[i]new_points =  F.relu(bn(conv(new_points)))new_points = torch.max(new_points, 2)[0]new_xyz = new_xyz.permute(0, 2, 1)return new_xyz, new_points

9.get_model

最后一次PointNetSetAbstraction,生成了一个表示该物体的一个特征向量。该向量被送到全连接层,最后一层就是分类个数(40或10)了。

B* 1024矢量经过,1024--->512---》256---》num_class(10或者40),就得到了类别输出,这时候和标签真值使用softmax作为loss函数,计算得到误差。


class get_model(nn.Module):def __init__(self,num_class,normal_channel=True):super(get_model, self).__init__()in_channel = 3 if normal_channel else 0self.normal_channel = normal_channelself.sa1 = PointNetSetAbstractionMsg(512, [0.1, 0.2, 0.4], [16, 32, 128], in_channel,[[32, 32, 64], [64, 64, 128], [64, 96, 128]])self.sa2 = PointNetSetAbstractionMsg(128, [0.2, 0.4, 0.8], [32, 64, 128], 320,[[64, 64, 128], [128, 128, 256], [128, 128, 256]])self.sa3 = PointNetSetAbstraction(None, None, None, 640 + 3, [256, 512, 1024], True)self.fc1 = nn.Linear(1024, 512)self.bn1 = nn.BatchNorm1d(512)self.drop1 = nn.Dropout(0.4)self.fc2 = nn.Linear(512, 256)self.bn2 = nn.BatchNorm1d(256)self.drop2 = nn.Dropout(0.5)self.fc3 = nn.Linear(256, num_class)def forward(self, xyz):B, _, _ = xyz.shapeif self.normal_channel:norm = xyz[:, 3:, :]xyz = xyz[:, :3, :]else:norm = Nonel1_xyz, l1_points = self.sa1(xyz, norm)l2_xyz, l2_points = self.sa2(l1_xyz, l1_points)l3_xyz, l3_points = self.sa3(l2_xyz, l2_points)x = l3_points.view(B, 1024)x = self.drop1(F.relu(self.bn1(self.fc1(x))))x = self.drop2(F.relu(self.bn2(self.fc2(x))))x = self.fc3(x)x = F.log_softmax(x, -1)return x,l3_points


http://www.ppmy.cn/news/1527634.html

相关文章

孙怡带你深度学习(2)--PyTorch框架认识

文章目录 PyTorch框架认识1. Tensor张量定义与特性创建方式 2. 下载数据集下载测试展现下载内容 3. 创建DataLoader&#xff08;数据加载器&#xff09;4. 选择处理器5. 神经网络模型构建模型 6. 训练数据训练集数据测试集数据 7. 提高模型学习率 总结 PyTorch框架认识 PyTorc…

java-springboot 实现文件 图片的上传 以及渲染

在 Java Spring Boot 应用中实现文件和图片的上传以及渲染&#xff0c;通常涉及以下几个步骤&#xff1a; 配置文件上传&#xff1a;使用 Spring Boot 的 MultipartResolver 来配置文件上传。 创建上传接口&#xff1a;创建一个 REST 控制器来处理上传请求。 保存文件到服务器&…

C#基础(14)冒泡排序

前言 其实到上一节结构体我们就已经将c#的基础知识点大概讲完&#xff0c;接下来我们会讲解一些关于算法相关的东西。 我们一样来问一下gpt吧&#xff1a; Q:解释算法 A: 算法是一组有序的逻辑步骤&#xff0c;用于解决特定问题或执行特定任务。它可以是一个计算过程、一个…

linux-Shell 编程-常用 Shell 脚本技巧

Linux Shell 编程&#xff1a;常用 Shell 脚本技巧 一、概述 Shell 脚本是 Linux 系统管理员和开发人员日常自动化任务的重要工具。通过编写 Shell 脚本&#xff0c;用户可以自动化重复性工作、简化系统维护、管理服务器资源等。Shell 脚本的强大之处在于其简洁和灵活性&…

手势识别-Yolov5模型-自制数据集训练

1、源码下载&#xff1a; 大家可以直接在浏览器搜索yolov5即可找到官方链接&#xff0c;跳转进github进行下载&#xff1a; 这里对yolov5模型补充说明一下&#xff0c;它是存在较多版本的&#xff0c;具体信息可在master->tags中查看&#xff0c;大家根据需要下载。这些不同…

Golang如何优雅的退出程序

Golang如何优雅的退出程序 在 Go 中优雅地退出程序&#xff0c;通常需要处理一些清理工作&#xff0c;如关闭文件、网络连接、释放资源等。以下是一些常见的方法&#xff1a; 一、使用 os.Signal 和 signal.Notify 捕获系统信号&#xff1a;可以使用 os/signal 包来捕获中断…

Android中如何处理运行时权限?

在Android中&#xff0c;处理运行时权限是开发过程中一个至关重要的环节&#xff0c;它自Android 6.0&#xff08;API级别23&#xff09;引入&#xff0c;旨在提高用户隐私保护和应用的透明度。以下将详细阐述Android中处理运行时权限的方法、步骤、注意事项以及相关的最佳实践…

最优化理论与自动驾驶(十一):基于iLQR的自动驾驶轨迹跟踪算法(c++和python版本)

最优化理论与自动驾驶&#xff08;四&#xff09;&#xff1a;iLQR原理、公式及代码演示 之前的章节我们介绍过&#xff0c;iLQR&#xff08;迭代线性二次调节器&#xff09;是一种用于求解非线性系统最优控制最优控制最优控制和规划问题的算法。本章节介绍采用iLQR算法对设定…

使用阿里OCR身份证识别

1、开通服务 免费试用 2、获取accesskay AccessKeyId和AccessKeySecret 要同时复制保存下来 因为后面好像看不AccessKeySecret了 3.Api 参考 https://help.aliyun.com/zh/ocr/developer-reference/api-ocr-api-2021-07-07-recognizeidcard?spma2c4g.11186623.0.0.7a9f4b1e5C…

STM32快速复习(十二)FLASH闪存的读写

文章目录 一、FLASH是什么&#xff1f;FLASH的结构&#xff1f;二、使用步骤1.标准库函数2.示例函数 总结 一、FLASH是什么&#xff1f;FLASH的结构&#xff1f; 1、FLASH简介 &#xff08;1&#xff09;STM32F1系列的FLASH包含程序存储器、系统存储器和选项字节三个部分&…

AcWing算法基础课-789数的范围-Java题解

大家好&#xff0c;我是何未来&#xff0c;本篇文章给大家讲解《AcWing算法基础课》789 题——数的范围。本文详细解析了一个基于二分查找的算法题&#xff0c;题目要求在有序数组中查找特定元素的首次和最后一次出现的位置。通过使用两个二分查找函数&#xff0c;程序能够高效…

随机规划及其MATLAB实现

目录 引言 随机规划的基本模型 随机动态规划 随机动态规划建模实例​(随机动态规划)&#xff1a; MATLAB中的随机规划实现 示例&#xff1a;两阶段随机规划 表格总结&#xff1a;随机规划求解方法与适用场景 结论 引言 随机规划&#xff08;Stochastic Programming&…

VulhubDC-4靶机详解

项目地址 https://download.vulnhub.com/dc/DC-4.zip实验过程 将下载好的靶机导入到VMware中&#xff0c;设置网络模式为NAT模式&#xff0c;然后开启靶机虚拟机 使用nmap进行主机发现&#xff0c;获取靶机IP地址 nmap 192.168.47.1-254根据对比可知DC-4的一个ip地址为192.1…

C++——多态的原理

多态的原理 多态的原理引入虚函数表 多态的原理 引入 如下代码的输出结果为&#xff08;&#xff09; A.编译报错 B.运行报错 C.8 D.12 上⾯题⽬运⾏结果12bytes&#xff0c;除了_b和_ch成员&#xff0c;还多⼀个__vfptr放对象的前⾯(注意有些平台可能会放到对象的最后⾯&am…

web基础—dvwa靶场(七)SQL Injection

SQL Injection&#xff08;SQL注入&#xff09; SQL Injection&#xff08;SQL注入&#xff09;&#xff0c;是指攻击者通过注入恶意的SQL命令&#xff0c;破坏SQL查询语句的结构&#xff0c;从而达到执行恶意SQL语句的目的。SQL注入漏洞的危害是巨大的&#xff0c;常常会导致…

AI绘画:科技赋能艺术的崭新时代

&#x1f4af;AI绘画&#xff1a;走进艺术创新的新时代 人工智能在改变世界的过程中&#xff0c;AI绘画工具逐渐成为创新的典范。 本文将为您揭示AI绘画背后的技术秘密、潜在的应用场景&#xff0c;并为您推荐几款出色的AI绘画工具&#xff0c;助您领略这一技术带来的艺术新体…

git bash中执行java命令乱码问题处理

Git Bash中执行java命令显示乱码 购机自带windows字符集为gbk&#xff0c;git bash默认为utf8&#xff0c;导致中文字符显示乱码 处理方案如下 顶部右键点击Options 选择Text&#xff0c;更换字符集即可

彩蛋岛 销冠大模型案例

彩蛋岛 销冠大模型案例 任务&#xff1a; https://kkgithub.com/InternLM/Tutorial/tree/camp3/docs/EasterEgg/StreamerSales 视频 https://www.bilibili.com/video/BV1f1421b7Du/?vd_source4ffecd6d839338c9390829e56a43ca8d 项目git地址&#xff1a; https://kkgithu…

VideoPlayer插件的用法

文章目录 1. 概念介绍2. 使用方法2.1 实现步骤2.2 具体细节 3. 示例代码4. 内容总结 我们在上一章回中介绍了"如何获取文件类型"相关的内容&#xff0c;本章回中将介绍如何播放视频.闲话休提&#xff0c;让我们一起Talk Flutter吧。 1. 概念介绍 播放视频是我们常用…

[深度学习]Pytorch框架

1 深度学习简介 应用领域&#xff1a;语音交互、文本处理、计算机视觉、深度学习、人机交互、知识图谱、分析处理、问题求解 2 发展历史 1956年人工智能元年2016年国内开始关注深度学习2017年出现Transformer框架2018年Bert和GPT出现2022年&#xff0c;chatGPT出现&#xff0…