GNN入门与实践——基于GraphSAGE在Cora数据集上的节点分类研究

ops/2025/3/6 1:16:54/

Hi,大家好,我是半亩花海。本文介绍了图神经网络GNN)中的一种重要算法——GraphSAGE,其通过采样邻居节点聚合信息,能够高效地处理大规模图数据,并通过一个完整的代码示例(包括数据预处理、模型定义、训练过程、验证与测试以及结果可视化)展示了如何在 Cora 数据集上实现节点分类任务。

目录

一、为什么我们需要图神经网络

GraphSAGE%EF%BC%9F-toc" name="tableOfContents" style="margin-left:0px">二、什么是 GraphSAGE

(一)概念

(二)核心思想

(三)数学公式

GraphSAGE%E5%AE%9E%E7%8E%B0-toc" name="tableOfContents" style="margin-left:0px">三、基于Cora数据集的GraphSAGE实现

(一)研究过程

(二)结果分析

GraphSAGE%20%E7%9A%84%E4%BC%98%E5%8A%BF%E4%B8%8E%E6%9C%AA%E6%9D%A5%E5%B1%95%E6%9C%9B-toc" name="tableOfContents" style="margin-left:0px">四、GraphSAGE的优势与未来展望


一、为什么我们需要图神经网络

近年来,随着深度学习的快速发展,神经网络在图像、文本和语音等领域取得了显著的成功。然而,这些传统方法主要适用于欧几里得数据(如图像和序列),而许多现实世界中的数据本质上是图结构的,例如社交网络、分子结构、知识图谱等。传统的神经网络难以直接处理这种非欧几里得数据。

神经网络(Graph Neural Network, GNN的出现为解决这一问题提供了新的思路。它通过建模节点之间的关系,能够有效地捕捉图结构中的复杂模式。GNN 已经在推荐系统、药物发现、交通预测等领域展现出巨大的潜力。

本文将通过一个具体的 GraphSAGE 示例,深入探讨 GNN 的基本原理、实现细节以及其在实际任务中的应用。


GraphSAGE%EF%BC%9F" name="%E4%BA%8C%E3%80%81%E4%BB%80%E4%B9%88%E6%98%AF%20GraphSAGE%EF%BC%9F">二、什么是 GraphSAGE

(一)概念

GraphSAGE(Graph Sample and Aggregation)是一种基于采样的图神经网络算法。与传统的图卷积网络(GCN)不同,GraphSAGE 不依赖于整个图的邻接矩阵进行计算,而是通过邻居节点进行采样和聚合生成节点表示。这种方法使得 GraphSAGE 更加高效且可扩展,尤其适用于大规模图数据

(二)核心思想

  • 采样(Sampling) :为了减少计算开销,GraphSAGE 对每个节点的邻居进行随机采样,而不是使用所有邻居。
  • 聚合(Aggregation) :通过聚合采样邻居的信息,更新目标节点的特征表示。常见的聚合方式包括均值聚合(mean)、最大池化(max-pooling)等。
  • 逐层传播(Layer-wise Propagation) :每一层都会根据前一层的节点表示和邻居信息生成新的节点表示。

(三)数学公式

假设我们有一个图 G=(V, E),其中 V 是节点集合,E 是边集合。对于第 l 层,目标节点 v 的表示 h_{v}^{(l)}​ 可以通过以下公式计算:

h_v^{(l)}=\sigma\left(W^{(l)} \cdot \text {AGGREGATE}\left(\left\{h_u^{(l-1)}, \forall u \in \mathcal{N}(v)\right\}\right)\right)

其中:

  • N(v) 表示节点 v 的邻居集合;
  • AGGREGATE 是聚合函数,例如均值聚合;
  • W(l) 是可学习的权重矩阵;
  • \sigma 是激活函数,例如 ReLU

GraphSAGE%E5%AE%9E%E7%8E%B0" name="%E4%B8%89%E3%80%81%E5%9F%BA%E4%BA%8ECora%E6%95%B0%E6%8D%AE%E9%9B%86%E7%9A%84GraphSAGE%E5%AE%9E%E7%8E%B0" style="background-color:transparent">三、基于Cora数据集的GraphSAGE实现

下面我们将通过一个完整的代码示例,展示如何使用GraphSAGE在Cora数据集上进行节点分类任务。

数据集及源代码链接:PyG-GraphSAGE(直接Download下来就行,好像有一处没加右括号,改正后直接运行main.py即可复现)。

(一)研究过程

1. 数据预处理

首先,我们加载 Cora 数据集并对其进行归一化处理:

import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from net import GraphSage
from data import CoraData
from data import CiteseerData
from data import PubmedData
from sampling import multihop_sampling
from collections import namedtuple# 数据集选择
dataset = "cora"
assert dataset in ["cora", "citeseer", "pubmed"]# 层数选择
num_layers = 2
assert num_layers in [2, 3]# 设置输入维度、隐藏层维度和邻居采样数量
if dataset == "cora":INPUT_DIM = 1433  # 输入维度if num_layers == 2:# Note: 采样的邻居阶数需要与GCN的层数保持一致HIDDEN_DIM = [256, 7]  # 隐藏单元节点数(2层模型,最后一个是输出的类别)NUM_NEIGHBORS_LIST = [10, 10]  # 每阶采样邻居的节点数else:# Note: 采样的邻居阶数需要与GCN的层数保持一致HIDDEN_DIM = [256, 128, 7]  # 隐藏单元节点数(2层模型,最后一个是输出的类别)NUM_NEIGHBORS_LIST = [10, 5, 5]  # 每阶采样邻居的节点数
elif dataset == "citeseer":INPUT_DIM = 3703  # 输入维度if num_layers == 2:# Note: 采样的邻居阶数需要与GCN的层数保持一致HIDDEN_DIM = [256, 6]  # 隐藏单元节点数(2层模型,最后一个是输出的类别)NUM_NEIGHBORS_LIST = [10, 10]  # 每阶采样邻居的节点数else:# Note: 采样的邻居阶数需要与GCN的层数保持一致HIDDEN_DIM = [256, 128, 6]  # 隐藏单元节点数(2层模型,最后一个是输出的类别)NUM_NEIGHBORS_LIST = [10, 5, 5]  # 每阶采样邻居的节点数
else:INPUT_DIM = 500  # 输入维度if num_layers == 2:# Note: 采样的邻居阶数需要与GCN的层数保持一致HIDDEN_DIM = [256, 3]  # 隐藏单元节点数(2层模型,最后一个是输出的类别)NUM_NEIGHBORS_LIST = [10, 10]  # 每阶采样邻居的节点数else:# Note: 采样的邻居阶数需要与GCN的层数保持一致HIDDEN_DIM = [256, 128, 3]  # 隐藏单元节点数(2层模型,最后一个是输出的类别)NUM_NEIGHBORS_LIST = [10, 5, 5]  # 每阶采样邻居的节点数# 定义超参数
BATCH_SIZE = 16  # 批处理大小
EPOCHS = 10  # 训练轮数
NUM_BATCH_PER_EPOCH = 20  # 每个epoch循环的批次数
if dataset == "citeseer":LEARNING_RATE = 0.1  # 学习率
else:LEARNING_RATE = 0.01
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"# 数据结构定义
Data = namedtuple('Data', ['x', 'y', 'adjacency_dict', 'train_mask', 'val_mask', 'test_mask'])# 载入数据
if dataset == "cora":data = CoraData().data
elif dataset == "citeseer":data = CiteseerData().data
else:data = PubmedData().data# 数据归一化
if dataset == "citeseer":x = data.x
else:x = data.x / data.x.sum(1, keepdims=True)  # 归一化数据,使得每一行和为1

说明:

  • INPUT_DIM 是节点特征的维度;
  • HIDDEN_DIM 是隐藏层的维度列表;
  • NUM_NEIGHBORS_LIST 是每层采样的邻居数量;
  • BATCH_SIZE 是每次训练时使用的样本数量;
  • EPOCHS 是总的训练轮数;
  • NUM_BATCH_PER_EPOCH 是每个 epoch 中的批次数量;
  • LEARNING_RATE 是学习率;
  • DEVICE 是使用的设备(CPU 或 GPU)。

2. 定义训练、验证、测试集

接下来,我们将数据集划分为训练集、验证集和测试集:

# 定义训练、验证、测试集
train_index = np.where(data.train_mask)[0]
train_label = data.y
val_index = np.where(data.val_mask)[0]
test_index = np.where(data.test_mask)[0]

说明:

  • train_index 是训练集的索引;
  • train_label 是训练集的标签;
  • val_index 是验证集的索引;
  • test_index 是测试集的索引。

3. 实例化模型

我们实例化一个 GraphSAGE 模型,并指定输入维度、隐藏层维度和邻居采样数量:

# 实例化模型
model = GraphSage(input_dim=INPUT_DIM,hidden_dim=HIDDEN_DIM,num_neighbors_list=NUM_NEIGHBORS_LIST,aggr_neighbor_method="mean",aggr_hidden_method="sum"
).to(DEVICE)print(model)

说明:

  • input_dim 是节点特征的维度;
  • hidden_dim 是隐藏层的维度列表;
  • num_neighbors_list 是每层采样的邻居数量;
  • aggr_neighbor_method 是邻居聚合的方式(例如均值聚合);
  • aggr_hidden_method 是隐藏层聚合的方式(例如求和)。

4. 定义损失函数和优化器

我们使用交叉熵损失函数和 Adam 优化器来训练模型:

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss().to(DEVICE)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=5e-4)

说明:

  • criterion 是交叉熵损失函数;
  • optimizer 是 Adam 优化器,带有权重衰减(L2 正则化)。

5. 定义训练函数

训练过程分为以下几个步骤:

(1)采样邻居:对每个批次的节点进行多跳采样,获取其邻居节点的特征。

(2)前向传播:将采样得到的节点特征送入模型,计算节点表示。

(3)损失计算:使用交叉熵损失函数计算损失,并通过反向传播更新模型参数。

# 定义训练函数
def train():train_losses = []train_acces = []val_losses = []val_acces = []model.train()  # 训练模式for e in range(EPOCHS):train_loss = 0train_acc = 0val_loss = 0val_acc = 0if e % 5 == 0:optimizer.param_groups[0]['lr'] *= 0.1  # 学习率衰减for batch in range(NUM_BATCH_PER_EPOCH):  # 每个epoch循环的批次数# 随机从训练集中抽取batch_size个节点(batch_size,num_train_node)batch_src_index = np.random.choice(train_index, size=(BATCH_SIZE,))# 根据训练节点提取其标签(batch_size,num_train_node)batch_src_label = torch.from_numpy(train_label[batch_src_index]).long().to(DEVICE)# 进行多跳采样(num_layers+1,num_node)batch_sampling_result = multihop_sampling(batch_src_index, NUM_NEIGHBORS_LIST, data.adjacency_dict)# 根据采样的节点id构造采样节点特征(num_layers+1,num_node,input_dim)batch_sampling_x = [torch.from_numpy(x[idx]).float().to(DEVICE) for idx in batch_sampling_result]# 送入模型开始训练batch_train_logits = model(batch_sampling_x)# 计算损失loss = criterion(batch_train_logits, batch_src_label)train_loss += loss.item()# 更新参数optimizer.zero_grad()loss.backward()  # 反向传播计算参数的梯度optimizer.step()  # 使用优化方法进行梯度更新# 计算训练精度_, pred = torch.max(batch_train_logits, dim=1)correct = (pred == batch_src_label).sum().item()acc = correct / BATCH_SIZEtrain_acc += accvalidate_loss, validate_acc = validate()val_loss += validate_lossval_acc += validate_accprint("Epoch {:03d} Batch {:03d} train_loss: {:.4f} train_acc: {:.4f} val_loss: {:.4f} val_acc: {:.4f}".format(e, batch, loss.item(), acc, validate_loss, validate_acc))train_losses.append(train_loss / NUM_BATCH_PER_EPOCH)train_acces.append(train_acc / NUM_BATCH_PER_EPOCH)val_losses.append(val_loss / NUM_BATCH_PER_EPOCH)val_acces.append(val_acc / NUM_BATCH_PER_EPOCH)# 测试test()res_plot(EPOCHS, train_losses, train_acces, val_losses, val_acces)

说明:

  • train() 函数负责训练模型,记录训练和验证的损失和准确率。
  • multihop_sampling 函数用于对节点进行多跳采样。
  • model 函数负责前向传播,计算节点表示。
  • criterion 函数计算损失。
  • optimizer 函数更新模型参数。
  • validate() 函数用于验证模型在验证集上的性能。
  • test() 函数用于测试模型在测试集上的性能。
  • res_plot 函数用于绘制训练和验证过程中的损失和准确率曲线。

6. 定义验证与测试函数

在验证和测试阶段,我们关闭梯度计算,并评估模型在验证集和测试集上的性能:

# 定义测试函数
def validate():model.eval()  # 测试模式with torch.no_grad():  # 关闭梯度val_sampling_result = multihop_sampling(val_index, NUM_NEIGHBORS_LIST, data.adjacency_dict)val_x = [torch.from_numpy(x[idx]).float().to(DEVICE) for idx in val_sampling_result]val_logits = model(val_x)val_label = torch.from_numpy(data.y[val_index]).long().to(DEVICE)loss = criterion(val_logits, val_label)predict_y = val_logits.max(1)[1]accuarcy = torch.eq(predict_y, val_label).float().mean().item()return loss.item(), accuarcy# 定义测试函数
def test():model.eval()  # 测试模式with torch.no_grad():  # 关闭梯度test_sampling_result = multihop_sampling(test_index, NUM_NEIGHBORS_LIST, data.adjacency_dict)test_x = [torch.from_numpy(x[idx]).float().to(DEVICE) for idx in test_sampling_result]test_logits = model(test_x)test_label = torch.from_numpy(data.y[test_index]).long().to(DEVICE)predict_y = test_logits.max(1)[1]accuarcy = torch.eq(predict_y, test_label).float().mean().item()print("Test Accuracy: ", accuarcy)

说明:

  • res_plot 函数用于绘制训练和验证过程中的损失和准确率曲线,并保存图像。

7. 可视化训练与验证过程

为了直观地观察模型在训练和验证过程中的表现,我们通过绘制损失和准确率曲线来分析模型的收敛性和性能。这段代码实现了训练损失、训练准确率、验证损失和验证准确率的可视化,并将结果保存为图像文件。

def res_plot(epoch, train_losses, train_acces, val_losses, val_acces):epoches = np.arange(0, epoch, 1)plt.figure()ax = plt.subplot(1, 2, 1)# 画出训练结果plt.plot(epoches, train_losses, 'b', label='train_loss')plt.plot(epoches, train_acces, 'r', label='train_acc')# plt.setp(ax.get_xticklabels())plt.legend()plt.subplot(1, 2, 2, sharey=ax)# 画出训练结果plt.plot(epoches, val_losses, 'k', label='val_loss')plt.plot(epoches, val_acces, 'g', label='val_acc')plt.legend()plt.savefig('res_plot.jpg')plt.show()

main函数:

# main函数,程序入口
if __name__ == '__main__':train()

(二)结果分析

(1)运行结果

(2)准确与损失率曲线 

从曲线上可以看出,整体准确率比较高且趋于稳定,但经充分训练之后,val_loss值仍然均位于1以上,可能与该模型的学习率过高、数据集处理不当、邻居采样不足等问题,所以此实例demo有待改进。 

GraphSAGE%20%E7%9A%84%E4%BC%98%E5%8A%BF%E4%B8%8E%E6%9C%AA%E6%9D%A5%E5%B1%95%E6%9C%9B" name="%E5%9B%9B%E3%80%81GraphSAGE%20%E7%9A%84%E4%BC%98%E5%8A%BF%E4%B8%8E%E6%9C%AA%E6%9D%A5%E5%B1%95%E6%9C%9B" style="background-color:transparent">四、GraphSAGE的优势与未来展望

通过上述实验,我们可以看到GraphSAGE在Cora数据集上的表现非常出色。相比于传统的GCN,GraphSAGE的采样机制使其能够更好地扩展到大规模图数据,同时保持较高的分类精度。

(1)优势

  • 高效性 :通过采样邻居节点,避免了对整个图的计算,显著降低了时间和空间复杂度。
  • 灵活性 :支持多种聚合方式,可以根据具体任务选择合适的策略。
  • 可扩展性 :适用于动态图和超大规模图。

(2)未来展望

尽管GraphSAGE已经取得了显著的成果,但仍有许多值得探索的方向:

  • 更高效的采样策略 :如何设计更智能的采样方法,进一步提升模型性能?
  • 跨领域应用 :如何将GNN应用于更多领域,例如健康估计、寿命预测、生物信息学、金融分析等?
  • 理论分析 :深入研究GNN的表达能力和泛化能力。

http://www.ppmy.cn/ops/163457.html

相关文章

React面试葵花宝典之二

36.Fiber的更新机制 React Fiber 更新机制详解 React Fiber 是 React 16 引入的核心架构重构,旨在解决可中断渲染和优先级调度问题,提升复杂应用的流畅性。其核心思想是将渲染过程拆分为可控制的工作单元,实现更细粒度的任务管理。以下是其…

神经网络:AI的网络神经

神经网络(Neural Networks)是深度学习的基础,是一种模仿生物神经系统结构和功能的计算模型。它由大量相互连接的节点(称为神经元)组成,能够通过学习数据中的模式来完成各种任务,如图像分类、语音…

Netty笔记3:NIO编程

Netty笔记1:线程模型 Netty笔记2:零拷贝 Netty笔记3:NIO编程 Netty笔记4:Epoll Netty笔记5:Netty开发实例 Netty笔记6:Netty组件 Netty笔记7:ChannelPromise通知处理 Netty笔记8&#xf…

深入探索Python机器学习算法:监督学习(线性回归,逻辑回归,决策树与随机森林,支持向量机,K近邻算法)

文章目录 深入探索Python机器学习算法:监督学习一、线性回归二、逻辑回归三、决策树与随机森林四、支持向量机五、K近邻算法 深入探索Python机器学习算法:监督学习 在机器学习领域,Python凭借其丰富的库和简洁的语法成为了众多数据科学家和机…

【Go】Go viper 配置模块

1. 配置相关概念 在项目开发过程中,一旦涉及到与第三方中间件打交道就不可避免的需要填写一些配置信息,例如 MySQL 的连接信息、Redis 的连接信息。如果这些配置都采用硬编码的方式无疑是一种不优雅的做法,有以下缺陷: 不同环境…

Python----Python爬虫(多线程,多进程,协程爬虫)

注意: 该代码爬取小说不久或许会失效,有时候该网站会被封禁,代码只供参考,不同小说不同网址会有差异 神印王座II皓月当空最新章节_神印王座II皓月当空全文免费阅读-笔趣阁 一、多线程爬虫 1.1、单线程爬虫的问题 爬虫通常被认为…

DeepSeek赋能Power BI:开启智能化数据分析新时代

在数据驱动决策的时代,数据分析工具的高效性与智能化程度成为决定企业竞争力的关键因素。Power BI作为一款功能强大的商业智能工具,深受广大数据分析师和企业用户的喜爱。而DeepSeek这一先进的人工智能技术的加入,更是为Power BI注入了新的活…

通过 Groq 后端加载Llama 模型,并调用Function call,也就是通过Groq 后端进行工具的绑定和调用

完整代码: import getpass import os from langchain.chat_models import init_chat_model from langchain_core.tools import tool from langchain_core.messages import HumanMessage, ToolMessage,SystemMessage# 如果没有设置 GROQ_API_KEY,则提示用…