【图神经网络】Pytorch图神经网络库——PyG创建消息传递网络

news/2024/11/24 16:30:04/

PyG创建消息传递网络

  • 消息传递基类:MessagePassing
  • GCN层的实现
  • 实现Edge Convolution
  • 内容来源:

将卷积算子推广到不规则域通常表示为邻域聚合或消息传递方案。在第 (k−1)(k-1)(k1)层中节点 iii的节点特征用 xi(k−1)∈RF\mathrm{x}_{i}^{(k-1)}\in \mathbb{R}^Fxi(k1)RF表示,从节点 jjj到节点 iii的边特征用 ej,i∈RD\mathrm{e}_{j,i}\in \mathbb{R}^Dej,iRD表示,消息传递图神经网络可以用以下公式描述:
xi(k)=γ(k)(xi(k−1),□j∈N(i)ϕ(k)(xi(k−1),xj(k−1),ej,i))\mathrm{x}_i^{(k)}=\gamma ^{(k)}(\mathrm{x}_i^{(k-1)},\Box_{j\in \mathcal{N}(i)}\phi^{(k)}(\mathrm{x}_i^{(k-1)},\mathrm{x}_j^{(k-1)},\mathrm{e}_{j,i}))xi(k)=γ(k)(xi(k1),jN(i)ϕ(k)(xi(k1),xj(k1),ej,i))
其中 □\Box表示一个可微、具有对称性的函数,比如sum,max,mean, γ\gammaγϕ\phiϕ表示不同的函数,比如MLPs。

消息传递基类:MessagePassing

PyG提供了消息传递基类,用于创建GNN自动化的消息传递机制。用户只需要定义函数γ\gammaγϕ\phiϕ,分别表示message()update()。聚合操作有aggr="add", aggr="mean" or aggr="max"等。

下面是一些相关方法的简介:
MessagePassing(aggr="add", flow="source_to_target", node_dim=-2):定义了一个聚集机制,三个参数分别为:聚集方式,消息传递方向以及沿哪个维度进行传播。
MessagePassing.propagate(edge_index, size=None, **kwargs):首次调用开始传播消息。获取边索引edge_index和所有用于构造消息和更新节点嵌入的附加数据。这个函数不但可以用于方阵,而且也可以用于二分图等非方阵图,但是需要传递size参数表明矩阵形状size=(N, M)
MessagePassing.message(...):构造消息到节点iii,但是根据传播方向有两种情况,如果边方向是(j,i)(j,i)(j,i)flow="source_to_target",即边是jjj指向iii,而且消息流向是源节点到目的节点,或者相反。通常将中心节点表示为iii,邻居节点表示为jjj
MessagePassing.update(aggr_out, ...):更新每个节点iii的嵌入向量,接受聚合的输出作为第一个参数以及最初传递给propagate()的任何参数。

总之,PyG要么是直接调用nn里面的层,或者自己实现网络层。调用nn里面的层在上次简介介绍过了,下面就看一下如何使用MessagePassing这个基类继承实现GCN和EdgeConv层

GCN层的实现

GCN层的数学定义如下:
xi(k)=∑j∈N(i)∪{i}1deg(i)⋅deg(j)⋅(ΘT⋅xj(k−1))+b\mathrm{x}_i^{(k)}=\sum_{j\in \mathrm{N}(i)\cup \{i\} }\frac{1}{\sqrt{deg(i)}\cdot \sqrt{deg(j)}}\cdot (\Theta^T\cdot \mathrm{x}_j^{(k-1)})+\mathrm{b}xi(k)=jN(i){i}deg(i)deg(j)1(ΘTxj(k1))+b
其中邻居节点的特征首先通过权重矩阵Θ\ThetaΘ进行变换,然后使用它们的度进行标准化(normalized),最终加和(summed up)。将其步骤写为如下几步:

  1. 向邻接矩阵(adjacency matrix)添加自循环(self-loops);
  2. 线性变换节点特征矩阵;
  3. 计算归一化系数;
  4. ϕ\phiϕ中Normalize节点特性
  5. 对相邻节点特征进行归纳(add聚合)
  6. 应用一个最终的偏差向量

步骤1-3通常是在消息传递之前计算的。使用MessagePassing基类可以很容易地处理步骤4-5。完整的GCN层实现如下:

import torch
from torch.nn import Linear, Parameter
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degreeclass GCNConv(MessagePassing):def __init__(self, in_channels, out_channels):super().__init__(aggr='add')  # "Add" aggregation (Step 5).self.lin = Linear(in_channels, out_channels, bias=False)self.bias = Parameter(torch.Tensor(out_channels))self.reset_parameters()def reset_parameters(self):self.lin.reset_parameters()self.bias.data.zero_()def forward(self, x, edge_index):# x has shape [N, in_channels]# edge_index has shape [2, E]# Step 1: Add self-loops to the adjacency matrix.edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))# Step 2: Linearly transform node feature matrix.x = self.lin(x)# Step 3: Compute normalization.row, col = edge_indexdeg = degree(col, x.size(0), dtype=x.dtype)deg_inv_sqrt = deg.pow(-0.5)deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]# Step 4-5: Start propagating messages.out = self.propagate(edge_index, x=x, norm=norm)# Step 6: Apply a final bias vector.out += self.biasreturn outdef message(self, x_j, norm):# x_j has shape [E, out_channels]# Step 4: Normalize node features.return norm.view(-1, 1) * x_j

GCNConv通过“add”传播继承自MessagePassing。该层的所有逻辑都发生在其forward()方法中。这里,我们首先使用torch_geometric.utils.add_self_loops()函数将自循环添加到边索引中(步骤1),以及通过调用torch.nn.Linear实例来线性变换节点特征(步骤2)。
对于每一个节点iii,归一化系数由节点度deg(i)deg(i)deg(i)得出,对于每个边(j,i)∈E(j,i) \in \mathcal{E}(j,i)E,转换到了1/(deg(i)⋅deg(j))1/(\sqrt{deg(i)}\cdot \sqrt{deg(j)})1/(deg(i)deg(j)),结果保存在形为[ num _ edge,](步骤3)的张量norm中。

然后,我们调用propagate(),该函数在内部调用message()aggregate()update()。我们将节点嵌入xxx和标准化系数norm作为消息传播的其他参数(additional arguments)。

message()函数中,我们需要通过norm规范相邻节点的特征x_j。这里,x_j表示lifted张量,它包含每个边的源节点特征,即每个节点的邻居。通过将_i_j附加到变量名称,可以自动提升节点特征。事实上,任何张量都可以这样转换,只要它们包含源节点或目标节点特征

这就是创建一个简单的消息传递层所需的全部内容。可以将此层用作深层体系结构的构建块。初始化和调用它非常简单:

conv = GCNConv(16, 32)
x = conv(x, edge_index)

实现Edge Convolution

Edge Convolution通过下式进行点云的处理:
xi(k)=max⁡j∈N(i)hΘ(xi(k−1),xj(k−1)−xi(k−1))\mathrm{x}_i^{(k)}=\max _{j\in \mathrm{N}(i)}h_\Theta(\mathrm{x}_i^{(k-1)},\mathrm{x}_j^{(k-1)}-\mathrm{x}_i^{(k-1)})xi(k)=jN(i)maxhΘ(xi(k1),xj(k1)xi(k1))
其中hΘh_\ThetahΘ表示一个MLP,类比GCN,我们使用MessagePassing 来实现这一层,同时使用max聚合。实现代码如下:

import torch
from torch.nn import Sequential as Seq, Linear, ReLU
from torch_geometric.nn import MessagePassingclass EdgeConv(MessagePassing):def __init__(self, in_channels, out_channels):super().__init__(aggr='max') #  "Max" aggregation.self.mlp = Seq(Linear(2 * in_channels, out_channels),ReLU(),Linear(out_channels, out_channels))def forward(self, x, edge_index):# x has shape [N, in_channels]# edge_index has shape [2, E]return self.propagate(edge_index, x=x)def message(self, x_i, x_j):# x_i has shape [E, in_channels]# x_j has shape [E, in_channels]tmp = torch.cat([x_i, x_j - x_i], dim=1)  # tmp has shape [E, 2 * in_channels]return self.mlp(tmp)

message()函数内部,对于每个边(j,i)∈E(j,i) \in \mathcal{E}(j,i)E,我们使用self.mlp将节点特征xix_ixi和相对源节点特征xj−xix_j-x_ixjxi进行转换。

根据定义式,就是首先使用MLP处理输入,在使用Max的aggregation操作。其中,MLP需要自己定义,然后aggregation操作只需要在初始化父类时传入参数aggr='max'即可,非常的方便。

对于边的卷积操作,实际上是动态的卷积,在每一层使用最近邻居在特征空间进行重计算。幸运的是,PyG有一个使用GPU加速的K-NN图产生的方法torch_geometric.nn.pool.knn_graph()

from torch_geometric.nn import knn_graphclass DynamicEdgeConv(EdgeConv):def __init__(self, in_channels, out_channels, k=6):super().__init__(in_channels, out_channels)self.k = kdef forward(self, x, batch=None):edge_index = knn_graph(x, self.k, batch, loop=False, flow=self.flow)return super().forward(x, edge_index)

knn_graph方法计算最近邻图,然后进一步调用EdgeConvforward函数进行传播。至此,就可以为DynamicEdgeConv留下一个干净的调用接口。

conv = DynamicEdgeConv(3, 128, k=6)
x = conv(x, batch)

内容来源:

[1] CREATING MESSAGE PASSING NETWORKS


http://www.ppmy.cn/news/6161.html

相关文章

300行HTML+CSS+JS代码实现动态圣诞树

文章目录1. 前言2. 效果展示3. 准备🍑 下载编译器🍑 下载插件4. 源码🍑 HTML🍑 JS🍑 CSS5. 结语1. 前言 一年一度的圣诞节和考研即将来临,那么这篇文章将用到前端三大剑客 HTML CSS JS 来实现动态圣诞树…

前端(htmlCSSJavaScript)基础

关于前端更多知识请关注官网:w3school 在线教程全球最大的中文 Web 技术教程。https://www.w3school.com.cn/ 1.HTML HTML(HyperText Markup Language):超文本标记语言 超文本:超越了文本的限制,比普通文本更强大。除了文字信息…

硬盘恢复工具软件哪个好?分享这些硬盘数据恢复工具软件

您刚刚删除了一些非常重要的文件! 不要惊慌……您仍然有很大的机会可以以很少甚至免费的方式取回它们。 我们正在深入研究当今最好的硬盘恢复软件。 我们认为有一个明显的赢家,但我们提供了一些其他选项,以防您需要更高级的功能或使用不同…

c# winform 重启自己 简单实现

1.情景 有些时候,系统会出问题,问题原因很难排除,但是重启问题就能修正,这时候我们就需要在一个检测到问题的时机,让系统进行一次重启。 2.代码 using System; using System.Windows.Forms;namespace 程序重启自己 …

PHP 精度计算问题(精确算法)

1. PHP 中的精度计算问题 当使用 php 中的 -*/ 计算浮点数时, 可能会遇到一些计算结果错误的问题 这个其实是计算机底层二进制无法精确表示浮点数的一个 bug, 是跨域语言的, 比如 js 中的 舍入误差 所以大部分语言都提供了用于精准计算的类库或函数库, 比如 php 中的 bc 高精…

Kaggle手写识别-卷积神经网络Top6%-代码详解

目录 1. Introduction 简介 2. Data preparation 数据准备 2.1 Load data 加载数据 2.2 Check for null and missing values 检查空值和缺失值 2.3 Normalization 规范化 2.4 Reshape 重塑 2.5 Label encoding 标签编码 2.6 Split training and valdiation set 拆分训…

国民技术 N32G45xxxx 编码器encoder

最近项目用到了一些单片机的编码器功能,有以下几种: 协议模式(串口,485-RTU,IIC等); 脉冲模式(2相,3相等); 而这两种模式的编码器分别具有不同的优劣点。 优点: 协议模式: 在经过实际测试后,发现协议模式的编码器,操作比较简单,通常只需要通过对应的通 信接口接收…

Python基础(十一)面向对象

目录 1. 简介 ①面向对象相关概念 ②面向对象三大特性 2.基本操作 2.1 类 2.2 对象 2.3 继承 1. 简介 面向对象(OOP)是一种对现实世界理解和抽象的方法,对象的含义是指在现实生活中能够看得见摸得着的具体事物,一句比较经…