【模型学习之路】手写+分析GAT

server/2024/11/13 1:49:00/

从GNN,到GCN,再到GAT

目录

文章目录

前言

GNN

GCN

GAT

公式

注意力实现

公式对比

多头注意力实现

测试&可视化


前言

读本文前,可以先过一遍【GNN图神经网络】入门到实战完整40讲!同济大佬用大白话的方式从零到一讲解原理基础及代码复现,主打一个通俗易懂!_哔哩哔哩_bilibili的1-12集。

GNN

GNN(Graph Neural Network,图神经网络)是一种专门用于处理图结构数据的深度学习模型,它的核心思想是通过聚合邻居节点的信息来更新每个节点的表示。这种更新过程可以捕捉到节点的局部邻域结构,从而学习到节点、边甚至整个图的高级特征表示。

GCN(Graph Convolutional Network,图卷积网络)是GNN的一种,它通过图卷积操作来更新节点的特征表示。

GCN

稍微来总结一下视频中的内容。

 (m, f)是第层的节点特征矩阵,m是节点个数,f表示每个节点的特征数。

 (f, f’) 是第层的可学习权重矩阵,也就是我们要训练的参数。

(m, m)是无自环邻接矩阵(即邻接矩阵,但是对角线全为0),(m, m)是度矩阵。首先,为了消息的自我传播,我们给它们加上单位矩阵。

节点更新函数为:

其中我们可以令一个变量:

式中是激活函数。

左乘和右乘时为了分别对列和行做标准化。

代码很简单,就是公式的堆叠,不做赘述。

python">import torch
import torch.nn as nn
import torch.nn.functional as Fdef normalized_adjacency(adj):"""输入A, 返回A^ """d = torch.diag(torch.sum(adj, dim=1))a = adj + torch.eye(adj.shape[0])d = d + torch.eye(adj.shape[0])d_inv_sqrt = torch.pow(d, -0.5)a_norm = d_inv_sqrt @ a @ d_inv_sqrtreturn a_normclass GraphConvolution(nn.Module):def __init__(self, in_features, out_features):super(GraphConvolution, self).__init__()self.weight = nn.Parameter(torch.FloatTensor(in_features, out_features))self.reset_parameters()def reset_parameters(self):stdv = 1. / (self.weight.size(1) ** 0.5)self.weight.data.uniform_(-stdv, stdv)def forward(self, x, adj):adj = normalized_adjacency(adj)return adj @ (x @ self.weight)class GCN(nn.Module):def __init__(self, in_features, hidden_features, out_features,n_layers, dropout=0.5):super(GCN, self).__init__()self.layers = nn.ModuleList()# 输入层到第一隐藏层self.layers.append(GraphConvolution(in_features, hidden_features))# 隐藏层到其他隐藏层for _ in range(n_layers - 2):self.layers.append(GraphConvolution(hidden_features, hidden_features))# 最后一个隐藏层到输出层self.layers.append(GraphConvolution(hidden_features, out_features))self.dropout = dropoutdef forward(self, x, adj):for layer in self.layers[:-1]:x = F.relu(layer(x, adj))x = F.dropout(x, self.dropout)x = self.layers[-1](x, adj)return F.log_softmax(x, dim=1)

GAT

公式

GAT(Graph Attention Networks,图注意力模型)。

本质上就是GCN应用了注意力机制,即A  成为了一个可变的、要更新的东西。

根据更新方式的不同,主要分为两种:

Global graph attention,就是任意两个点都要进行attention运算。

Mask graph attention,注意力机制的运算只在邻居顶点上进行。

根据实现中矩阵表示方法的不同,又分为密集矩阵(就是平时的矩阵)和稀释矩阵表示。

此外,又有很多种计算注意力系数的方法。

这里采用当年GAT论文中的做法,使用密集矩阵,计算注意力系数的方法与论文中保持一致。

在图注意力网络(GAT)中, 的计算公式以及节点更新公式如下:

计算注意力系

这里,  表示节点 hi  相对于节点  的注意力值,  是共享的可学习参数,||表示向量拼接操作。  是一种激活函数。

(1, f)(f, f’)的维度是(1, f’),两个拼接就会得到(1, 2f’) 是一个列向量,维度是(2f’,1),两者相乘得到一个标量。

对行求softmax。

    节点更新:

指的是与i节点相邻的所有节点形成的集合。

从单头到多头,  是注意力头的数量。

无非就是多个注意力头的结构然后全都concat起来,有时也会采用多个注意力头球均值的方法:

注意力实现

上代码

python">import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as Fclass GraphAttentionLayer(nn.Module):def __init__(self, in_features, out_features, dropout, alpha, concat=True):super(GraphAttentionLayer, self).__init__()self.dropout = dropoutself.in_features = in_featuresself.out_features = out_featuresself.alpha = alphaself.concat = concatself.W = nn.Parameter(torch.zeros(size=(in_features, out_features)))  # [f, f']nn.init.xavier_uniform_(self.W.data, gain=1.414)self.a = nn.Parameter(torch.zeros(size=(2*out_features, 1)))  # [2f', 1]nn.init.xavier_uniform_(self.a.data, gain=1.414)self.leakyrelu = nn.LeakyReLU(self.alpha)def forward(self, input_h, adj):  # input_h: [B, m, f]h = input_h @ self.W  # [B, m, f']m = h.size()[1]a_input = torch.cat([h.repeat(1, 1, m).view(-1, m * m, self.out_features),h.repeat(1, m, 1)], dim=-1).view(-1, m, m, 2 * self.out_features)  # [B, m, m, 2f']e = self.leakyrelu((a_input @ self.a).squeeze(3))  # [B, m, m, 1] -> [B, m, m]# 如果邻接(adj_ij=1),就用e_ij# 如果不邻接(adj_ij=0),就用-9e15, 之后会被softmax“屏蔽”掉# 这样就只留用了邻接的Global graph attention -> Mask graph attentionzero_vec = -9e15*torch.ones_like(e)attention = torch.where(adj > 0, e, zero_vec)  # [B, m, m]attention = F.softmax(attention, dim=-1)attention = F.dropout(attention, self.dropout)h_prime = attention @ h  # [B, m, m] @ [B, m, f'] -> [B, m, f']if self.concat:return F.elu(h_prime)else:return h_prime

先解释一下这一行:

python">a_input = torch.cat([h.repeat(1, 1, m).view(-1, m * m, self.out_features),h.repeat(1, m, 1)], dim=-1).view(-1, m, m, 2 * self.out_features)  # [B, m, m, 2f']

这一行十分抽象。在这个代码中中h维度是(B, m, f’)。这里先不看batch_size,我们先假设h的维度是(m, f’),h由m个节点组成,每个节点有f个特征。我们可以先把h写成:

 (m, f’) 显然,这里m=3

第一个量:

h.repeat(1,m),横着重复:

   (m, mf’)

h.view(mm, f’):

​​​​​​​

第二个量:

h.repeat(m,1),竖着重复:

 (mm, f’)

将两者拼起来:

 (mm, 2f’)

最后展开:

View(m,m,2f’):

(m, m, 2f’)

加了batch_size之后一个道理。

芜湖,这样处理之后,就有:

则:

解释完毕

公式对比

 

对比一下,GAT的公式描述是这样的:

在代码里面我们是这样实现GAT的:

再得到a_input,之后有:

写成矩阵形式:

然后用A来掩盖掉e中不相邻的值。这一步不写公式了。(因为我也不知道数学上怎么写)。

更新节点

回顾一下GCN的节点更新公式:

发现其实GAT本质就是用了一个新的权重计算方式。

多头注意力实现

之后多头注意力

python">class MultiHeadGATLayer(nn.Module):def __init__(self, nfeat, nhid, nout, dropout, alpha, nheads):"""Dense version of GAT."""super(MultiHeadGATLayer, self).__init__()self.dropout = dropoutself.attentions = nn.ModuleList([GraphAttentionLayer(nfeat, nhid, dropout=dropout, alpha=alpha, concat=True)for _ in range(nheads)])self.out = GraphAttentionLayer(nhid * nheads, nout, dropout=dropout, alpha=alpha, concat=False)def forward(self, x, adj):x = F.dropout(x, self.dropout)x = torch.cat([att(x, adj) for att in self.attentions], dim=2)  # [B, m, n_hid*n_heads]x = F.dropout(x, self.dropout)x = F.elu(self.out(x, adj))return x  # [B, m, n_out]

简单组装一下:

python">class GAT(nn.Module):def __init__(self, nfeat, nhid, nclass, dropout, alpha, nheads, nlayers):"""Dense version of GAT."""super(GAT, self).__init__()self.dropout = dropout# 由输入到隐藏self.gats = nn.ModuleList([MultiHeadGATLayer(nfeat, nhid, nfeat, dropout=dropout, alpha=alpha, nheads=nheads)for _ in range(nlayers)])  # [B, m, f] -> [B, m, h] -> [B, m, f]self.out = GraphAttentionLayer(nfeat, nclass, dropout=dropout, alpha=alpha, concat=False)  # [B, m, n_class]def forward(self, x, adj):for gat in self.gats:x = gat(x, adj)x = F.dropout(x, self.dropout)x = self.out(x, adj)return F.log_softmax(x, dim=-1)  # [B, m, n_class]

测试&可视化

python"># 生成测试样例
x = torch.randn(1, 3, 2)  # [batch_size, m, f]
adj = torch.ones(1, 3, 3)  # [batch_size, m, m]
model = GAT(2, 4, 2, 0.6, 0.6, 4, 3)
print(model(x, adj).shape)  # [B, m, n_class] [1, 3, 2]modelData = "./demo.pth"
torch.onnx.export(model, (x, adj), modelData)
netron.start(modelData)

感觉还行 


http://www.ppmy.cn/server/140999.html

相关文章

C++顶层const与底层const

顶层const意味着被修饰的对象本身是一个常量。 顶层const可以用来修饰基本数据类型(如int、float等)和自定义类型(如结构体、类等)的对象。 顶层const修饰的对象的值不能被修改,但是该对象可以被赋予另一个值&#xf…

股民情绪识别的LSTM-NBM混合模型

大家好,我是带我去滑雪! 利用之前爬取2023年10月17日至2024年7月13日的65万余条东方财富网的上证指数股吧的股民评论数据,基于jieba库对股民情绪进行识别,在进行中文分词、去除停用词、合并同义词和长短句分离后,对长文…

备忘录模式:保存对象状态的设计模式

1. 引言 在软件开发中,常常需要保存一个对象的状态,以便将来能够恢复到该状态。在某些情况下,这种需求显得尤为重要,例如在撤销操作、版本控制以及游戏进度保存等场景中。备忘录模式(Memento Pattern)正是…

Python 数据可视化详解教程

Python 数据可视化详解教程 数据可视化是数据分析中不可或缺的一部分,它通过图形化的方式展示数据,帮助我们更直观地理解和分析数据。Python 作为一种强大的编程语言,拥有丰富的数据可视化库,如 Matplotlib、Seaborn、Plotly 和 …

react 类组件和函数组件区别

一 类组件需要使用this关键字来访问props和状态,而函数组件则可以直接访问这些值。原来只有类组件可以使用的特性,比如状态和生命周期方法,现在函数组件通过Hooks也可以使用。函数组件通常更简洁,更易于测试和理解。类组件目前仍…

什么是红黑树

红黑树是一种自平衡的二叉查找树,在计算机科学中常用于组织数据,如数字块等,其典型的用途是实现关联数组。以下是对红黑树的详细介绍,以及左旋、右旋、变色等操作的解析: 一、红黑树简介 起源与命名:红黑树…

算法训练(leetcode)二刷第二十一天 | 491. 非递减子序列、*46. 全排列、*47. 全排列 II、D

刷题记录 491. 非递减子序列*46. 全排列*47. 全排列 IID 491. 非递减子序列 leetcode题目地址 题目提供的数据有重复,但结果集中不可有重复组合,且不允许排序,因此需要借助Set或额外的hash表进行标记当前层是否使用了相同元素。 时间复杂度…

SSRF〈2〉

SSRF的进阶 1.Gopher协议的利用 1.gopher协议可以通过url指向指定IP端口发送任意内容&#xff0c;模拟大多数TCP协议&#xff0c;是SSRF中的一把利刃。 gopher协议URL&#xff1a; gopher://<host>:<port>/_<url编码的TCP数据> 这个url编码的TCP数据是goph…