图神经网络实战(14)——基于节点嵌入预测链接

ops/2024/10/11 13:26:19/

神经网络实战(14)——基于节点嵌入预测链接

    • 0. 前言
    • 1. 图自编码器
    • 2. 变分图自编码器
    • 3. 实现变分图自编码器
    • 小结
    • 系列链接

0. 前言

我们已经了解了如何使用图神经网络 (Graph Neural Networks, GNN) 生成节点嵌入,我们可以使用这些嵌入执行矩阵分解 (matrix factorization) 完成链接预测任务。本节将介绍两种用于链接预测的 GNN 架构——图自编码器 (Graph Autoencoder, GAE) 和变分图自编码器 (Variational Graph Autoencoder, VGAE)。

1. 图自编码器

图自编码器 (Graph Autoencoder, GAE) 和变分图自编码器 (Variational Graph Autoencoder, VGAE) 架构都是 KipfWelling2016 年所提出的。它们分别对应于两种流行的神经网络架构——自编码器 (Autoencoder) 和变分自编码器 (Variational Autoencoder, VAE)。为了便于理解,我们将首先介绍 GAEGAE 由两个模块组成:

  • 编码器 (encoder):一个经典的双层图卷积网络 (Graph Convolutional Network, GCN),使用以下方式计算节点嵌入:
    Z = G C N ( X , A ) Z=GCN(X,A) Z=GCN(X,A)
  • 解码器 (decoder):使用矩阵分解 (matrix factorization) 和 sigmoid 函数 σ σ σ 来近似邻接矩阵 A ^ \hat A A^,从而输出概率:
    A ^ = σ ( Z T Z ) \hat A=\sigma(Z^TZ) A^=σ(ZTZ)

需要注意的是,我们并不是要对节点或图进行分类,而是预测邻接矩阵 A ^ \hat A A^ 中每个元素的概率(介于 01 之间),因此使用两个邻接矩阵元素之间的二进制交叉熵损失(负对数似然)来训练 GAE
L B C E = ∑ i ∈ V , j ∈ V − A i j l o g ( A ^ i j ) − ( 1 − A i j ) l o g ( 1 − A ^ i j ) \mathcal L_{BCE}=\sum_{i\in V,j\in V}-A_{ij}log(\hat A_{ij})-(1-A_{ij})log(1-\hat A_{ij}) LBCE=iV,jVAijlog(A^ij)(1Aij)log(1A^ij)

然而,邻接矩阵通常非常稀疏,这会使 GAE 偏向于预测零值。有两种简单的方法可以修正这一偏差。首先,可以在上述损失函数中增加一个权重,使偏向于 A i i = 1 A_{ii}=1 Aii=1。其次,可以在训练过程中采样较少的零值,使标签更加均衡。
这种架构非常灵活,编码器可以换成其它类型的图神经网络 (Graph Neural Networks, GNN) (如 GraphSAGE、图同构网络 (Graph Isomorphism Network, GIN) 等),多层感知机 (Multilayer Perceptron, MLP) 也可以作为解码器,另一种改进方法是将 GAE 转换为变分图自编码器。

2. 变分图自编码器

图自编码器 (Graph Autoencoder, GAE) 和变分图自编码器 (Variational Graph Autoencoder, VGAE) 之间的区别与自编码器 (Autoencoder) 和变分自编码器 (Variational Autoencoder, VAE) 之间的区别相同。VGAE 不直接学习节点嵌入,而是学习正态分布,然后通过采样生成嵌入。VGAE 也由两个模块组成:

  • 编码器 (encoder):由共享第一层的两个图卷积网络 (Graph Convolutional Network, GCN) 组成。其目标是学习每个潜正态分布的参数,均值 μ μ μ (由 G C N μ GCN_μ GCNμ 学习)和方差 σ 2 σ^2 σ2 (在实践中通过 G C N σ GCN_σ GCNσ 学习其对数形式)
  • 解码器 (decoder):使用重参数化技巧 (reparametrization trick),从学习到的分布 ( μ , σ 2 ) (μ, σ^2) (μ,σ2) 中采样嵌入值 z i z_i zi, 。然后,它使用潜变量之间的内积来近似邻接矩阵 A ^ = σ ( Z T Z ) \hat A= σ(Z^TZ) A^=σ(ZTZ)

对于 VGAE,确保编码器的输出服从正态分布非常重要,因此需要在损失函数中添加一个新项,Kullback-Leibler 散度 (KL 散度),它用于测量两个分布之间的差异。VGAE 的总体损失如下,也称为证据下界 (evidence lower bound, ELBO):
L E L B O = L B C E − K L [ q ( Z ∣ X , A ) ∣ ∣ p ( Z ) ] \mathcal L_{ELBO}=\mathcal L_{BCE}-KL[q(Z|X,A)||p(Z)] LELBO=LBCEKL[q(ZX,A)∣∣p(Z)]
其中, q ( Z ∣ X , A ) q(Z|X,A) q(ZX,A) 表示编码器, p ( Z ) p(Z) p(Z) Z Z Z 的先验分布。通常可以使用ROC 曲线下面积 (area under the ROC, AUROC) 和平均精度 (average precision, AP) 这两个指标来评估模型的性能。
接下来,我们使用 PyTorch Geometric 实现 VGAE

3. 实现变分图自编码器

变分图自编码器 (Variational Graph Autoencoder, VGAE) 与其它类型的图神经网络 (Graph Neural Networks, GNN) (如 GraphSAGE、图同构网络 (Graph Isomorphism Network, GIN) 等)实现有两个主要区别:

  • 对数据集进行预处理,随机删除一些链接以进行预测
  • 创建一个编码器模型,并将其添加到 VGAE 类中,而不是直接从头开始实现 VGAE

接下来,使用 PyTorch Geometric (PyG) 构建 VGAE 模型。

(1) 首先,导入所需的库,并定义设备:

import numpy as np
import torch
import matplotlib.pyplot as plt
import torch_geometric.transforms as T
from torch_geometric.datasets import Planetoiddevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

(2) 创建一个 transform 对象,对输入特征进行归一化处理,将张量转移到预定义的设备中,并随机分割链接(在本节中,我们按照 85: 5:10 的比例进行拆分),将 add_negative_train_samples 参数设置为 False,因为模型已经执行了负采样,所以数据集中不需要负采样:

transform = T.Compose([T.NormalizeFeatures(),T.ToDevice(device),T.RandomLinkSplit(num_val=0.05, num_test=0.1, is_undirected=True, split_labels=True, add_negative_train_samples=False),
])

(3) 使用定义的 transform 对象加载 Cora 数据集:

dataset = Planetoid('.', name='Cora', transform=transform)

(4) RandomLinkSplit 方法会按预定比例拆分生成训练/验证/测试集,并存储这些数据集:

train_data, val_data, test_data = dataset[0]

(5) 接下来,实现编码器。首先,需要导入 GCNConv 和 VGAE

from torch_geometric.nn import GCNConv, VGAE

声明一个新类,在这个类中,需要三个图卷积网络 (Graph Convolutional Network, GCN) 层,一个作为共享层、一个用于近似均值 μ μ μ,第三个用于近似方差值(实践中使用对数标准差, log ⁡ σ \log\sigma logσ):

class Encoder(torch.nn.Module):def __init__(self, dim_in, dim_out):super().__init__()self.conv1 = GCNConv(dim_in, 2 * dim_out)self.conv_mu = GCNConv(2 * dim_out, dim_out)self.conv_logstd = GCNConv(2 * dim_out, dim_out)def forward(self, x, edge_index):x = self.conv1(x, edge_index).relu()return self.conv_mu(x, edge_index), self.conv_logstd(x, edge_index)

(6) 初始化 VGAE 并将编码器作为输入,默认情况下,VGAE 使用内积作为解码器:

model = VGAE(Encoder(dataset.num_features, 16)).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

(7)train() 方法中,首先使用 model.encode() 计算嵌入矩阵 Z Z Z,此函数从学习到的分布中对样本嵌入进行采样。然后,使用 model.recon_loss() (二进制交叉熵损失)和 model.kl_loss() (KL 散度) 计算 ELBO 损失。解码器会被隐式调用来计算交叉熵损失:

def train():model.train()optimizer.zero_grad()z = model.encode(train_data.x, train_data.edge_index)loss = model.recon_loss(z, train_data.pos_edge_label_index) + (1 / train_data.num_nodes) * model.kl_loss()loss.backward()optimizer.step()return float(loss)

(8) test() 函数只需调用 VGAE 的专用方法:

@torch.no_grad()
def test(data):model.eval()z = model.encode(data.x, data.edge_index)return model.test(z, data.pos_edge_label_index, data.neg_edge_label_index)

(9) 对模型进行 301epoch 的训练,并打印 AUCAP 指标:

for epoch in range(301):loss = train()val_auc, val_ap = test(test_data)if epoch % 50 == 0:print(f'Epoch {epoch:>2} | Loss: {loss:.4f} | Val AUC: {val_auc:.4f} | Val AP: {val_ap:.4f}') 

输出结果如下所示:

Epoch  0 | Loss: 3.4412 | Val AUC: 0.6842 | Val AP: 0.7043
Epoch 50 | Loss: 1.3321 | Val AUC: 0.6628 | Val AP: 0.6881
Epoch 100 | Loss: 1.1690 | Val AUC: 0.7512 | Val AP: 0.7526
Epoch 150 | Loss: 1.0348 | Val AUC: 0.8173 | Val AP: 0.8128
Epoch 200 | Loss: 0.9980 | Val AUC: 0.8415 | Val AP: 0.8364
Epoch 250 | Loss: 0.9698 | Val AUC: 0.8576 | Val AP: 0.8457
Epoch 300 | Loss: 0.9339 | Val AUC: 0.8727 | Val AP: 0.8620

(10) 在测试集上对模型进行评估:

test_auc, test_ap = test(test_data) 
print(f'Test AUC: {test_auc:.4f} | Test AP {test_ap:.4f}')# Test AUC: 0.8727 | Test AP 0.8620

(11) 手动计算近似邻接矩阵 A ^ \hat A A^

z = model.encode(test_data.x, test_data.edge_index) 
Ahat = torch.sigmoid(z @ z.T)
print(Ahat)
'''
tensor([[0.8468, 0.5072, 0.7254,  ..., 0.7016, 0.8674, 0.8545],[0.5072, 0.8120, 0.7991,  ..., 0.4572, 0.6988, 0.6898],[0.7254, 0.7991, 0.8623,  ..., 0.5731, 0.8622, 0.8496],...,[0.7016, 0.4572, 0.5731,  ..., 0.6582, 0.6973, 0.6925],[0.8674, 0.6988, 0.8622,  ..., 0.6973, 0.9259, 0.9155],[0.8545, 0.6898, 0.8496,  ..., 0.6925, 0.9155, 0.9051]],device='cuda:0', grad_fn=<SigmoidBackward0>)
'''

VGAE 的训练速度很快,输出结果也很容易理解,但我们已经知道 GCN 并不是最具表达能力的运算符。为了提高模型的表达能力,我们需要采用更好的技术。

小结

链接预测可以帮助我们发现隐藏的关联规律,从而为网络分析、推荐系统等问题提供有效的解决方案。在本节中,介绍了如何使用图神经网络 (Graph Neural Networks, GNN) 实现链接预测,学习了基于节点嵌入的链接预测技术,包括图自编码器 (Graph Autoencoder, GAE) 和变分图自编码器 (Variational Graph Autoencoder, VGAE),并使用边级随机分割和负采样在 Cora 数据集上实现了 VGAE 模型。

系列链接

神经网络实战(1)——图神经网络(Graph Neural Networks, GNN)基础
神经网络实战(2)——图论基础
神经网络实战(3)——基于DeepWalk创建节点表示
神经网络实战(4)——基于Node2Vec改进嵌入质量
神经网络实战(5)——常用图数据集
神经网络实战(6)——使用PyTorch构建图神经网络
神经网络实战(7)——图卷积网络(Graph Convolutional Network, GCN)详解与实现
神经网络实战(8)——图注意力网络(Graph Attention Networks, GAT)
神经网络实战(9)——GraphSAGE详解与实现
神经网络实战(10)——归纳学习
神经网络实战(11)——Weisfeiler-Leman测试
神经网络实战(12)——图同构网络(Graph Isomorphism Network, GIN)
神经网络实战(13)——经典链接预测算法


http://www.ppmy.cn/ops/50026.html

相关文章

Flink的简单学习五

一 动态表与连续查询 1.1 动态表 1.是flink的支持流数据Table API 和SQL的核心概念。动态表随时间的变化而变化 2.在流上面定义的表在内部是没有数据的 1.2 连续查询 1.永远不会停止&#xff0c;结果是一张动态表 二 Flink SQL 2.1 sql行 1.先启动启动flink集群 yarn-see…

springboot接入springai-openAi代理和智谱ai调用示例

这里写自定义目录标题 背景配置具体代码总结 背景 一说到调用openAI的api或者做一些小项目&#xff0c;大部分例子都是python或者node实现的&#xff0c;后来发现spring出了对于openai的支持框架&#xff0c;所以尝试用一用。这里是SpringAI的地址&#xff0c;有兴趣的可以去官…

m4s转mp3——B站缓存视频提取音频

前言 しかのこのこのここしたんたん&#xff08;鹿乃子乃子虎视眈眈&#xff09;非常之好&#xff0c;很适合当闹钟&#xff0c;于是缓存了视频&#xff0c;想提取音频为mp3 直接改后缀可乎&#xff1f;格式转换工具&#xff1f; 好久之前有记录过转MP4的&#xff1a; m4s转为…

Kafka之ISR机制的理解

文章目录 Kafka的基本概念什么是ISRISR的维护机制ISR的作用ISR相关配置参数同步过程示例代码总结 Kafka中的ISR&#xff08;In-Sync Replicas同步副本&#xff09;机制是确保数据高可用性和一致性的核心组件。 Kafka的基本概念 在Kafka中&#xff0c;数据被组织成主题&#xf…

单调栈(续)、由斐波那契数列讲述矩阵快速降幂技巧

在这里先接上一篇文章单调栈&#xff0c;这里还有单调栈的一道题 题目一&#xff08;单调栈续&#xff09; 给定一个数组arr&#xff0c; 返回所有子数组最小值的累加和 就是一个数组&#xff0c;有很多的子数组&#xff0c;每个数组肯定有一个最小值&#xff0c;要把所有子…

前端框架是什么

前端框架是预先编写好的JavaScript代码集合&#xff0c;旨在帮助开发者快速搭建Web应用程序的界面和交互逻辑。以下是一些常见的前端框架&#xff0c;按照字母顺序排列&#xff0c;并简要介绍其特点&#xff1a; Angular 由Google开发&#xff0c;原名AngularJS&#xff0c;后…

EVTOL垂直起降-变化就在空气中

混合动力垂直起降&#xff08;eVTOL&#xff09;飞行器有能力改变空中交通生态系统。了解航空运输面临的挑战以及公司如何利用新机遇。 介绍 一个世纪前&#xff0c;航空先驱格伦柯蒂斯&#xff08;Glenn Curtiss&#xff09;首次推出了自动飞机&#xff0c;这是一种带有可拆…

Spring Boot中使用logback出现LOG_PATH_IS_UNDEFINED文件夹

1.首先查看&#xff0c;application.properties 文件是否按格式编写 logging.pathmylogs logging.configclasspath:logback-spring.xml2.查看 logback-spring.xml <springProperty scope"context" name"LOG_HOME" source"logging.path"/> …