图神经网络:(处理点云)PointNet++的实现

news/2025/2/19 17:40:53/

文章说明:
1)参考资料:PYG官方文档。超链。
2)博主水平不高,如有错误还望批评指正。
3)我在百度网盘上传了这篇文章的jupyter notebook和有关文献。超链。提取码8848。

文章目录

    • 简单前置工作学习
    • 文献阅读
    • PointNet++的实现
    • 模型问题

简单前置工作学习

工作目标:根据点云去进行40分类。
工作流程:1.读取PyG内置的几何图形数据。2.随机但是均匀采样。3.K最邻近算法构边建图。4.使用PointNet++进行图分类。
导库,下载数据,导库,定义函数

from torch_geometric.datasets import GeometricShapes
dataset=GeometricShapes(root='/Data/GeometricShapes')
import matplotlib.pyplot as plt
def visualize_mesh(pos,face):fig=plt.figure()ax=fig.add_subplot(111,projection='3d')ax.axes.xaxis.set_ticklabels([])ax.axes.yaxis.set_ticklabels([])ax.axes.zaxis.set_ticklabels([])ax.plot_trisurf(pos[:,0],pos[:,1],pos[:,2],triangles=face.t(),antialiased=False)plt.show()

PS1:这段代码会在C盘生成一个DATA的文件并将数据集放在DATA中,有强迫症注意一下。
PS2:就是几何图形网格。细节可以点击这里。
打印信息与可视化

print(dataset)
data=dataset[0]
print(data)
visualize_mesh(data.pos,data.face)
data=dataset[4]
print(data)
visualize_mesh(data.pos,data.face)

jupyter notebook内输出如下
在这里插入图片描述
导库以及定义函数

from torch_geometric.transforms import SamplePoints
import torch
def visualize_points(pos,edge_index=None,index=None):fig=plt.figure(figsize=(4, 4))if edge_index is not None:for (src,dst) in edge_index.t().tolist():src=pos[src].tolist()dst=pos[dst].tolist()plt.plot([src[0],dst[0]],[src[1],dst[1]],linewidth=1,color='black')if index is None:plt.scatter(pos[:,0],pos[:,1],s=50,zorder=1000)else:mask=torch.zeros(pos.size(0),dtype=torch.bool)mask[index]=Trueplt.scatter(pos[~mask,0],pos[~mask,1],s=50,color='lightgray',zorder=1000)plt.scatter(pos[mask,0],pos[mask,1],s=50,zorder=1000)plt.axis('off')plt.show()

从图形表面均匀地采样,打印信息与可视化

dataset.transform=SamplePoints(num=256)
data=dataset[0]
print(data)
visualize_points(data.pos)
data=dataset[4]
print(data)
visualize_points(data.pos)

jupyter notebook内输出如下
在这里插入图片描述

文献阅读

参考文献: PointNet++: Deep Hierarchical Feature Learning on Point Sets in a Metric Space

文章概述: “Deep learning on point sets for 3d classification and segmentation”是参考文献之前前沿工作,核心思想对每个点空间编码然后聚合所有单点要素到全局的空间。显然这样无法捕捉局部特征。受到卷积神经网络启发,这里参考文献便就来了。具体步骤: 第一步:进行局部划分;第二步:组合局部特征;第三步:加工局部特征,重复上述过程直到点云所有特征都被利用。所以面临三个问题。第一问:如何进行局部划分。第二问:如何组合局部特征。第三问:如何加工局部特征。解决第一问:Farthest Point Sampling,FPS。解决第二问:Ball Query。解决第三问:上面那篇文章的Point

分层的点云学习器: Sampling layer: Farthest Point Sampling,FPS。 可以使用K最近邻算法但是不好。固定一个区域更加有普适性。PS:注意一下KNN与Ball Query的区别。Grouping layer: 输入: N × ( d + C ) N \times (d+C) N×(d+C) 以及 N ′ × d N'\times d N×d 输出: N ′ × K × ( d + C ) N' \times K \times (d + C) N×K×(d+C)。符号说明: N N N是点的数量, d d d是质心坐标, C C C是点的特征维数, N ′ N' N是质心数量, K K K是邻域内点数量。Point Net layer: 输入: N ′ × K × ( d + C ) N' \times K \times (d + C) N×K×(d+C) 输出: N ′ × ( d + C ′ ) N' \times (d + C') N×(d+C) 。这个模型鲁棒性强,对于不均匀的数据效果同样。这个图挺好的。
在这里插入图片描述
PS1:原文还有其他很好工作,有兴趣有时间建议去看,但是我们这里跳过。
PS2:对于上面前置工作,由于采用是均匀的,可以这样建图。如下:
导库

from torch_cluster import knn_graph

打印信息与可视化

data=dataset[0]
data.edge_index=knn_graph(data.pos,k=6)
print(data.edge_index.shape)
visualize_points(data.pos,edge_index=data.edge_index)
data=dataset[4]
data.edge_index=knn_graph(data.pos, k=6)
print(data.edge_index.shape)
visualize_points(data.pos,edge_index=data.edge_index)

jupyter notebook内输出如下
在这里插入图片描述

PointNet++的实现

我们使用数学公式首先进行EdgeConv的描述: h i ( l ) = m a x j ∈ N i M L P ( h i ( l − 1 ) , h j ( l − 1 ) − h i ( l − 1 ) ) h_i^{(l)}=max_{j\in \mathcal{N}_i}MLP(h_i^{(l-1)},h_j^{(l-1)}-h_i^{(l-1)}) hi(l)=maxjNiMLP(hi(l1),hj(l1)hi(l1))。Point++类似于这个公式: h i ( l ) = m a x j ∈ N i M L P ( h i ( l − 1 ) , p j ( l − 1 ) − p i ( l − 1 ) ) h_i^{(l)}=max_{j\in \mathcal{N}_i}MLP(h_i^{(l-1)},p_j^{(l-1)}-p_i^{(l-1)}) hi(l)=maxjNiMLP(hi(l1),pj(l1)pi(l1))
搭建多层的PointNel++

from torch_geometric.nn import MessagePassing
from torch.nn import Sequential,Linear,ReLUclass PointNetLayer(MessagePassing):def __init__(self,in_channels,out_channels):super().__init__(aggr='max')self.mlp=Sequential(Linear(in_channels+3,out_channels),ReLU(),Linear(out_channels,out_channels))def forward(self,h,pos,edge_index):return self.propagate(edge_index,h=h,pos=pos)def message(self,h_j,pos_j,pos_i):input=pos_j-pos_iif h_j is not None:input=torch.cat([h_j,input],dim=-1)return self.mlp(input)
from torch_geometric.nn import global_max_poolclass PointNet(torch.nn.Module):def __init__(self):super().__init__()self.conv1=PointNetLayer(3,32)self.conv2=PointNetLayer(32,32)self.classifier=Linear(32,dataset.num_classes)def forward(self,pos,batch):edge_index=knn_graph(pos,k=16,batch=batch,loop=True)h=self.conv1(h=pos,pos=pos,edge_index=edge_index)h=h.relu()h=self.conv2(h=h,pos=pos,edge_index=edge_index)h=h.relu()h=global_max_pool(h,batch)return self.classifier(h)model=PointNet()
print(model)
#输出如下
#PointNet(
#  (conv1): PointNetLayer()
#  (conv2): PointNetLayer()
#  (classifier): Linear(in_features=32, out_features=40, bias=True)
#)

导库,训测拆分数据变换以及划分批量

from torch_geometric.loader import DataLoader
train_dataset=GeometricShapes(root='/Data/GeometricShapes',train=True,transform=SamplePoints(128))
test_dataset=GeometricShapes(root='/Data/GeometricShapes',train=False,transform=SamplePoints(128))
train_loader=DataLoader(train_dataset,batch_size=10,shuffle=True)
test_loader=DataLoader(test_dataset,batch_size=10)

进行实验

model=PointNet();optimizer=torch.optim.Adam(model.parameters(),lr=0.01);criterion=torch.nn.CrossEntropyLoss()def train(model,optimizer,loader):model.train()total_loss=0for data in loader:optimizer.zero_grad()logits=model(data.pos,data.batch)loss=criterion(logits,data.y)loss.backward()optimizer.step()total_loss+=loss.item()*data.num_graphsreturn total_loss/len(train_loader.dataset)def test(model,loader):model.eval()total_correct=0for data in loader:logits=model(data.pos,data.batch)pred=logits.argmax(dim=-1)total_correct+=int((pred==data.y).sum())return total_correct/len(loader.dataset)for epoch in range(1,51):loss=train(model,optimizer,train_loader)test_acc=test(model,test_loader)print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Test Accuracy: {test_acc:.4f}')
#输出如下(这里只有最后一次):
#Epoch: 50, Loss: 0.7294, Test Accuracy: 0.8250

模型问题

出现问题: 由于模型使用坐标进行输入并且选择笛卡尔坐标系传递信息所以旋转坐标就不可行。可以按照如下方式进行实验。

from torch_geometric.transforms import Compose,RandomRotate
random_rotate=Compose([RandomRotate(degrees=180,axis=0),RandomRotate(degrees=180,axis=1),RandomRotate(degrees=180,axis=2),
])
dataset=GeometricShapes(root='/DATA//GeometricShapes',transform=random_rotate)
data=dataset[0]
print(data)
visualize_mesh(data.pos,data.face)
data=dataset[4]
print(data)
visualize_mesh(data.pos,data.face)

jupyter notebook内输出如下
在这里插入图片描述

transform=Compose([random_rotate,SamplePoints(num=128),
])
test_dataset=GeometricShapes(root='/DATA/GeometricShapes',train=False,transform=transform)
test_loader=DataLoader(test_dataset,batch_size=10)
test_acc=test(model,test_loader)
print(f'Test Accuracy: {test_acc:.4f}')
#输出如下:
#Test Accuracy: 0.2000
print(len(test_dataset))
#输出如下:
#40

可以看到,模型效果,就不好了。有解决方法的。暂时就这样吧。


http://www.ppmy.cn/news/74197.html

相关文章

Ajax,前后端分离开发,前端工程化,Element,Vue路由,打包部署

Ajax介绍 Axios <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta http-equiv"X-UA-Compatible" content"IEedge"><meta name"viewport" content"widthdevice-wid…

frp+nginx+xposed搭建xp模块集群

frpcnginxxposed搭建xp模块集群 前言实现逻辑配置内网穿透实现负载均衡 前言 为了能够稳定的采集一些app的详情页数据&#xff0c;就得借助xposed&#xff0c;xposed跟NanoHTTPD配合使用就可以在手机端开启接口服务&#xff0c;直接调用手机端的接口就能获取我们想要的数据&am…

uniapp(切换页面保存上一页的数据,路由传参特殊字符)

uniapp 当前页返回到上一页时&#xff0c;保留当前页的数据&#xff0c;修改上一页对应的数据 let pages getCurrentPages();// #ifdef MP-WEIXIN || APP-PLUSlet currPage pages[pages.length - 1].$vm;let prevPage pages[pages.length - 2].$vm; //上一个页面// #e…

Moonbeam联合Multichain和AWS Startups正式推出Bear Necessities Hackathon黑客松

我们很高兴宣布Bear Necessities Hackathon正式启动。本次黑客松包含7个挑战&#xff0c;超过7万美金的奖池等你来领&#xff01;我们欢迎所有的BUILDers参加&#xff0c;这是开发者们探索Moonbeam并构建跨链用例的机会&#xff01; 本次黑客松由Moonbeam、Multichain和AWS St…

入门JavaScript编程:上手实践四个常见操作和一个轮播图案例

部分数据来源&#xff1a;ChatGPT 简介 JavaScript是一门广泛应用于Web开发的脚本语言&#xff0c;它主要用于实现动态效果和客户端交互。下面我们将介绍几个例子&#xff0c;涵盖了JavaScript中一些常见的操作&#xff0c;包括&#xff1a;字符串、数组、对象、事件等。 例子…

传染病学模型 | SIR 、SEIR传染病学模型

文章目录 SIR传染病学模型SEIR传染病学模型参考资料SIR传染病学模型 SIR模型是一种流行病学模型,用于描述传染病在人群中的传播过程。SIR模型将人群分为三个类别:易感者(Susceptible)、感染者(Infectious)和康复者(Recovered)。三个类别之间的转移可以用以下三个微分方…

测试的分类

1 按照开发阶段&#xff08;软件开发周期&#xff09; 单元测试是对软件的组成单元进行测试。其目的是检验软件基本组成单位的正确性。测试的对象是软件设计的最小单位——模块&#xff0c;故又称为模块测试。集成测试是将程序模块采用适当的集成策略组装起来&#xff0c;对系…

怎么验证文法是否为LL(1)文法

要验证一个文法是否是LL(1)文法&#xff0c;需要进行以下步骤&#xff1a; 消除左递归&#xff1a;如果文法存在左递归&#xff0c;则需要先对其进行消除。 提取左公因子&#xff1a;如果文法存在左公因子&#xff0c;则需要将其提取。 构造FIRST集合&#xff1a;对于每个非终…