GNN动手实践(二):复现图注意力网络GAT

news/2025/1/15 22:08:32/

参考论文:Graph Attention Networks

一.前言

GAT(图注意力网络)是GNNs中重要的SOTA模型,该模型是从空域角度来进行定义,能够用消息传递范式来进行解释。GAT与GCN最大的不同便是它在图节点邻域聚合的过程中引入了注意力机制来计算邻居对当前正在聚合的节点的重要程度。本文的内容包括:图注意力网络的架构介绍、基于PyG来复现GAT模型。

二.GAT架构介绍

正如第一节中介绍的那样,GAT最大的贡献便是将注意力机制引入到图卷积中来,下面先给出该模型的架构图:

GAT_Architecture

从该图可以看出,GAT在聚合过程中,需要计算1阶邻居节点对当前节点的重要程度,即 α ⃗ i j \vec{\alpha}_{ij} α ij,然后进行加权求和。下面是该模型对应的消息传递范式的数学形式:
h i ( l + 1 ) = ∑ j ∈ N ( i ) α i , j W ( l ) h j ( l ) α i j l = softmax ⁡ j ( e i j l ) = exp ⁡ ( e i j l ) ∑ k ∈ N i exp ⁡ ( e i k l ) e i j l = L e a k y R e L U ( a ( W h i ( l ) ∥ W h j ( l ) ) ) \begin{aligned} h_i^{(l+1)} & = \sum_{j\in \mathcal{N}(i)} \alpha_{i,j} W^{(l)} h_j^{(l)} \\ \alpha_{ij}^{l} &=\operatorname{softmax}_{j}\left(e_{ij}^{l}\right)=\frac{\exp \left(e_{ij}^{l}\right)}{\sum_{k \in \mathcal{N}_{i}} \exp \left(e_{ik}^{l}\right)} \\ e_{ij}^{l} & = \mathrm{LeakyReLU}\left(a (W h_{i}^{(l)} \| W h_{j}^{(l)})\right)\end{aligned} hi(l+1)αijleijl=jN(i)αi,jW(l)hj(l)=softmaxj(eijl)=kNiexp(eikl)exp(eijl)=LeakyReLU(a(Whi(l)Whj(l)))
其中 h i ( l ) h_i^{(l)} hi(l) h j ( l ) h_j^{(l)} hj(l)是GAT模型中第 l l l层的节点特征, a a a是一个单层前馈神经网络, ∥ \| 表示向量的拼接操作, W W W是权重矩阵, N ( i ) \mathcal{N}(i) N(i)表示节点 i i i的1阶邻域。

另外,作者在论文中使用了多头注意力(Multi-Head Attention)机制,即可以将上述聚合公式扩展为如下形式:
h i ( l + 1 ) = ∥ k = 1 K σ ( ∑ j ∈ N i α i j k W k h j ( l ) ) h_{i}^{(l+1)}=\|_{k=1}^{K} \sigma\left(\sum_{j \in \mathcal{N}_{i}} \alpha_{ij}^{k} \mathbf{W}^{k} h_{j}^{(l)}\right) hi(l+1)=k=1KσjNiαijkWkhj(l)
其中 K K K表示注意力头的个数。

需要注意的是,若在最后一层使用多头注意力机制,则使用求平均来代替拼接操作,即:
h i ( l + 1 ) = σ ( 1 K ∑ k = 1 K ∑ j ∈ N i α i j k W k h j ( l ) ) h_{i}^{(l+1)}=\sigma\left(\frac{1}{K} \sum_{k=1}^{K} \sum_{j \in \mathcal{N}_{i}} \alpha_{i j}^{k} \mathbf{W}^{k} h_{j}^{(l)}\right) hi(l+1)=σK1k=1KjNiαijkWkhj(l)

三.复现工作

3.1 复现GAT模型

对于GAT模型,本文采用PyG来对其进行复现操作。若对PyG中如何实现消息传递神经网络不怎么了解的可以参见博主之前的博文《PyG教程(6):自定义消息传递网络》。

GAT模型包含两个图注意力卷积层的GAT,两个卷积层之间的非线性激活为ELU,该模型的实现源码如下:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import softmax, add_remaining_self_loopsclass GATConv(MessagePassing):def __init__(self, in_feats, out_feats, alpha, drop_prob, num_heads):super().__init__(aggr="add")self.drop_prob = drop_probself.num_heads = num_headsself.out_feats = out_feats // num_headsself.lin = nn.Linear(in_feats, self.out_feats *self.num_heads, bias=False)self.a = nn.Linear(2*self.out_feats, 1)self.leakrelu = nn.LeakyReLU(alpha)def forward(self, x, edge_index):edge_index, _ = add_remaining_self_loops(edge_index)# Whh = self.lin(x)h_prime = self.propagate(edge_index, x=h)return h_primedef message(self, x_i, x_j, edge_index_i):x_i = x_i.view(-1, self.num_heads, self.out_feats)x_j = x_j.view(-1, self.num_heads, self.out_feats)# a(Wh_i, Wh_j)e = self.a(torch.cat([x_i, x_j], dim=-1)).permute(1, 0, 2)# LeakReLU(a(Wh_i, Wh_j))e = self.leakrelu(e.permute(1, 0, 2))# softmax(e_{ij})alpha = softmax(e, edge_index_i)alpha = F.dropout(alpha, self.drop_prob, self.training)return (x_j * alpha).view(x_j.size(0), -1)class GAT(nn.Module):def __init__(self, in_feats, hidden_feats, y_num,alpha=0.2, drop_prob=0., num_heads=[1, 1]):super().__init__()self.drop_prob = drop_probself.gatconv1 = GATConv(in_feats, hidden_feats, alpha, drop_prob, num_heads[0])self.gatconv2 = GATConv(hidden_feats, y_num, alpha, drop_prob, num_heads[1])def forward(self, x, edge_index):x = self.gatconv1(x, edge_index)x = F.elu(x)x = F.dropout(x, self.drop_prob, self.training)out = self.gatconv2(x, edge_index)return F.log_softmax(out, dim=1)if __name__ == "__main__":conv = GATConv(in_feats=64, out_feats=64, alpha=0.2,num_heads=8, drop_prob=0.2)x = torch.rand(4, 64)edge_index = torch.tensor([[0, 1, 1, 2, 0, 2, 0, 3], [1, 0, 2, 1, 2, 0, 3, 0]], dtype=torch.long)x = conv(x, edge_index)print(x.shape)

3.2 实验

3.2.1 实验数据集

本文以Cora数据集为例进行实验,该数据集为一个论文引用网络,包含了2708篇论文,每篇论文都由1433维的词向量表示。该论文引用网络包含5429条边,表示论文间的引用关系。数据集中的论文分为7个类别。

3.2.2 超参配置

本文实验的超级参数来源于GAT论文,具体如下表所示:

ParameterValue
dropout rate0.6
weight_decay5e-4
learning rate0.01
hidden size64
num_head[8,1],即第一个卷积层包含8个注意力头,第二层包含1个
epochs300

3.2.3 实验结果展示

在实验过程中,使用训练集进行模型的参数更新,然后使用验证集来筛选最佳的模型,最后将最佳的模型在测试集上进行测评。某次实验的运行结果截图如下:

gat_result
从最后的结果来看,和论文里面的报的相应数据集的结果差不多。当然,限于时间原因,没有细致调参和做一些可视化相关的工作,有感兴趣的小伙伴可以自行研究。

四.结语

完整项目Github地址:GAT
以上便是本文的全部内容,要是觉得不错的话就点个赞或关注一下博主吧,你们的支持是博主继续创作的不解动力,当然若是有任何问题也敬请批评指正!!!


http://www.ppmy.cn/news/498303.html

相关文章

送给刚入职 有压力和迷茫的你

文/北妈 阅读本文需要 2.3分钟 一 一直想写一篇这个话题出来,和大家分享下我遇到这种情况时候的一些处理方法和心理状态,我尽量言简意赅,试着把篇幅缩短。 因为每次写原创,我都不自禁都写特别长,这应该是病&#xff0c…

沐曦与百度飞桨完成兼容性测试,助力计算机视觉应用发展

近日,沐曦集成电路(上海)有限公司(以下简称“沐曦”)的曦思N100人工智能推理GPU与百度飞桨完成 I 级兼容性测试。测试结果显示,双方兼容性表现良好,整体运行稳定。这是沐曦自2022年9月加入“硬件…

沐曦加入龙蜥社区,聚焦技术创新,繁荣开源生态

近日,沐曦集成电路(上海)有限公司(以下简称“沐曦”)正式加入龙蜥社区(OpenAnolis)。 沐曦致力于为异构计算提供高性能 GPU 芯片和解决方案。公司于 2020 年 9 月成立于上海市浦东新区临港新片区…

php压力测试ab.exe,apache ab工具页面压力测试返回结果含义解释

ab是apache自带的一个很好用的压力测试工具,当安装完apache的时候,就可以在bin下面找到ab 参数说明及示例 我们可以模拟100个并发用户,对一个页面发送1000个请求 其中-n代表请求数,-c代表并发数 返回结果: ##首先是apache的版本信…

沐曦加入openKylin,提供全栈GPU芯片及解决方案

近日,沐曦集成电路 (上海) 有限公司(以下简称“沐曦”)签署openKylin社区CLA(Contributor License Agreement 贡献者许可协议),正式加入openKylin开源社区。 沐曦致力于为异构计算提供全栈GPU芯片及解决方案…

MT6762/处理器参数,helio P22芯片资料介绍

MT6762/helio P22具有集成的蓝牙、fm、wlan和gps模块,是一个高度集成的基带平台,包括调制解调器和应用处理子系统,以支持LTE/LTE-A和C2K智能手机应用程序。该芯片集成了ARM Cortex-A53,工作频率可达2.0GHz,并集成了功能强大的多标准视频编解码器。此外,一组广泛的接口 连…

曦力DVD备份及转换软件 Xilisoft DVD Ripper Ultimate 6绿色版

用WinRAR解压即可运行,无需安装。可以拷贝到USB硬盘,便于携带 【下载地址】 http://u.115.com/file/cln6a28w http://www.rayfile.com/zh-cn/files/eddcdc75-c783-11e0-89c1-0015c55db73d/ 曦力DVD备份及转换软件Xilisoft DVD Ripper 是一款DVD备份及转换…

干货 | 诚迈科技联合联发科技、Linaro推出《Make it happen!联发科技曦力X20开发板技术公开课(上海站)》精彩回顾

2016年12月10日,由AlphaSTAR极客社区主办的《Make it happen!联发科技曦力X20开发板技术公开课》在上海完美落幕。联发科技、Linaro及诚迈科技热力助阵,电子发烧友社区、IC咖啡参与协办了此次活动。众多创客思维碰撞,与技术大咖面…