什么是图神经网络?

ops/2025/2/7 22:30:05/

一、概念

        图神经网络(Graph Neural Network, GNN)是一类专门用于处理图结构数据的神经网络。图结构数据广泛存在于各种实际应用中,如社交网络、分子结构、知识图谱等。GNN通过在图的节点和边上进行信息传递和聚合,能够有效地捕捉图结构中的复杂关系和特征

        GNN的输入通常是一个图 G=(V,E),其中 V 是节点集合,E 是边集合。每个节点 v∈V 可能有一个特征向量 ​,每条边 (u,v)∈E 可能有一个特征向量​。

二、核心算法

        GNN的核心思想是通过迭代地更新节点的表示来捕捉图结构中的信息。每一轮迭代(也称为层)包括以下两个步骤:

  • 消息传递(Message Passing):每个节点从其邻居节点接收信息。
  • 节点更新(Node Update):每个节点根据接收到的信息和自身的特征更新其表示。

        假设我们有一个图 G=(V,E),每个节点 v∈V 的特征向量为 ,每条边 (u,v)∈E 的特征向量为 ​。GNN的计算公式可以表示为:

1、消息传递

        其中,N(v)表示节点 v 的邻居节点集合,M是消息传递函数,是节点 v 在第 k 层接收到的消息。

2、节点更新

        其中,U是节点更新函数,是节点 v 在第 k 层的表示。

三、python实现

        这里,我们构建一个create_graph函数来生成一个空手道俱乐部的图(Karate Club Graph),并为每个节点生成一个特征向量(单位矩阵)和标签(根据俱乐部分组)。通过加载 Karate Club 图数据集,我们可以获得一个社交网络图,其中包含 34 个节点和 78 条边。我们为每个节点生成标签(0 或 1),表示节点属于哪个社区(Mr. Hi 或 Officer)。进而基于这份数据进行GNN分类。

python">import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import networkx as nx
import numpy as np
import matplotlib.pyplot as plt# 生成一个小的图数据集
def create_graph():# 加载 Karate Club 图数据集,这是一个社交网络图,包含 34 个节点和 78 条边。G = nx.karate_club_graph()features = np.eye(G.number_of_nodes())# 为每个节点生成标签(0 或 1),表示节点属于哪个社区(Mr. Hi 或 Officer)。labels = np.array([G.nodes[i]['club'] == 'Mr. Hi' for i in range(G.number_of_nodes())], dtype=int)return G, features, labels# 定义原始GNN模型
class GNN(nn.Module):def __init__(self, input_dim, hidden_dim, output_dim):super(GNN, self).__init__()self.fc1 = nn.Linear(input_dim, hidden_dim)self.fc2 = nn.Linear(hidden_dim, output_dim)def forward(self, x, adj):h = F.relu(self.fc1(x))# 使用邻接矩阵 adj 聚合邻居节点的信息。h = torch.matmul(adj, h)h = self.fc2(h)return F.log_softmax(h, dim=1)# 训练和测试函数
def train(model, optimizer, features, labels, adj, train_mask, epochs=10):model.train()for epoch in range(epochs):optimizer.zero_grad()output = model(features, adj)# 计算负对数似然损失loss = F.nll_loss(output[train_mask], labels[train_mask])loss.backward()optimizer.step()print(f'Epoch: {epoch + 1}, Loss: {loss.item():.4f}')def test(model, features, labels, adj, mask):model.eval()with torch.no_grad():output = model(features, adj)pred = output[mask].max(1)[1]acc = pred.eq(labels[mask]).sum().item() / mask.sum().item()return acc# 主函数
# 创建图数据集
G, features, labels = create_graph()
adj = nx.adjacency_matrix(G).todense()
adj = torch.FloatTensor(adj)
features = torch.FloatTensor(features)
labels = torch.LongTensor(labels)# 训练和测试掩码,前 30 个节点用于训练
train_mask = torch.BoolTensor([True if i < 30 else False for i in range(len(labels))])
test_mask = ~train_mask# 初始化模型和优化器
model = GNN(input_dim=features.shape[1], hidden_dim=16, output_dim=2)
optimizer = optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)# 训练模型
train(model, optimizer, features, labels, adj, train_mask)# 测试模型
train_acc = test(model, features, labels, adj, train_mask)
test_acc = test(model, features, labels, adj, test_mask)
print(f'Train Accuracy: {train_acc:.4f}, Test Accuracy: {test_acc:.4f}')# 可视化结果
def plot_graph(G, labels, pred=None):pos = nx.spring_layout(G)plt.figure(figsize=(8, 8))nx.draw(G, pos, with_labels=True, node_color=labels, cmap=plt.cm.rainbow, node_size=500, font_color='white')if pred is not None:nx.draw_networkx_nodes(G, pos, node_color=pred, cmap=plt.cm.rainbow, node_size=200, alpha=0.5)plt.show()plot_graph(G, labels.numpy(), pred=model(features, adj).max(1)[1].numpy())

四、总结

        GNN能够直接处理图结构数据。通过端到端的方式进行训练,GNN能够直接从原始图数据中学习特征和表示,这使得它在处理社交网络、分子结构、知识图谱等任务中具有天然的优势。然而,GNN的计算复杂度较高,尤其是在处理大规模图数据时。每一轮迭代都需要进行消息传递和节点更新,这使得GNN的计算量较大,训练和推理速度较慢。在深层GNN中,节点的表示可能会变得过于相似,导致过平滑问题。此外,如果图数据存在噪声或不完整,GNN的性能也会受到影响。

 


http://www.ppmy.cn/ops/156559.html

相关文章

MySQL初学之旅(5)详解查询

目录 1.前言 2.正文 2.1聚合查询 2.1.1count() 2.1.2sum() 2.1.3avg() 2.1.4max() 2.1.5min() 2.1.6总结 2.2分组查询 2.2.1group by字句 2.2.2having字句 2.2.3group by与having的关系 2.3联合查询 2.3.1笛卡尔积 2.3.2内连接 2.3.3外连接 2.3.4自连接 2.3…

OpenGL学习笔记(十):初级光照:材质 Materials

文章目录 材质属性设置材质属性光的属性设置光照属性 在现实世界里&#xff0c;每个物体会对光产生不同的反应。比如&#xff0c;钢制物体看起来通常会比陶土花瓶更闪闪发光&#xff0c;一个木头箱子也不会与一个钢制箱子反射同样程度的光。有些物体反射光的时候不会有太多的散…

使用scikit-learn中的K均值包进行聚类分析

聚类是无监督学习中的一种重要技术&#xff0c;用于在没有标签信息的情况下对数据进行分析和组织。K均值算法是聚类中最常用的方法之一&#xff0c;其目标是将数据点划分为K个簇&#xff0c;使得每个簇内的数据点更加相似&#xff0c;而不同簇之间的数据点差异较大。 准备自定…

C# 添加、替换、提取、或删除Excel中的图片

在Excel中插入与数据相关的图片&#xff0c;能将关键数据或信息以更直观的方式呈现出来&#xff0c;使文档更加美观。此外&#xff0c;对于已有图片&#xff0c;你有事可能需要更新图片以确保信息的准确性&#xff0c;或者将Excel 中的图片单独保存&#xff0c;用于资料归档、备…

双系统共用一个蓝牙鼠标

前言 由于蓝牙鼠标每次只能配置一个系统&#xff0c;每次切换系统后都需要重新配对&#xff0c;很麻烦&#xff0c;双系统共用一个鼠标原理就是通过windows注册表中找到鼠标每次生成的mac地址以及配置&#xff0c;将其转移到linux上。 解决 1. 首先进入linux系统 进行蓝牙鼠…

PostgreSql 函数异常处理

BEGIN 逻辑块 EXCEPTION WHEN 错误码&#xff08;如&#xff1a;unique_violation&#xff09; or others THEN 异常逻辑块 END; 在PL/pgSQL函数中&#xff0c;如果没有异常捕获&#xff0c;函数会在发生错误时直接退出&#xff0c;与其相关的事物也会随之回滚。我们可以通过使…

最大矩阵的和

最大矩阵的和 真题目录: 点击去查看 E 卷 100分题型 题目描述 给定一个二维整数矩阵&#xff0c;要在这个矩阵中选出一个子矩阵&#xff0c;使得这个子矩阵内所有的数字和尽量大&#xff0c;我们把这个子矩阵称为和最大子矩阵&#xff0c;子矩阵的选取原则是原矩阵中一块相互…

python基础入门:2.3字符串高级操作

字符串高级操作 1. 字符串格式化技巧 1.1 f-string&#xff08;Python 3.6&#xff09; 基础用法&#xff1a; name "Alice" age 25 print(f"{name}今年{age}岁") # Alice今年25岁高级格式控制&#xff1a; pi 3.1415926 # 保留两位小数 print(f&…