PyTorch Geometric(PyG)机器学习实战

ops/2025/2/7 4:01:59/

PyTorch Geometric(PyG)机器学习实战

在图神经网络(GNN)的研究和应用中,PyTorch Geometric(PyG)作为一个基于PyTorch的库,提供了高效的图数据处理和模型构建功能。
本文将通过一个节点分类任务,演示如何使用PyG进行机器学习实战。

1. 环境准备

首先,确保已安装PyTorch和PyG。可以使用以下命令进行安装:

pip install torch
pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric2. 导入必要的库import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv3. 加载数据集我们使用PyG自带的Planetoid数据集,这里以Cora数据集为例。dataset = Planetoid(root='/tmp/Cora', name='Cora')
data = dataset[0]4. 定义GCN模型我们将构建一个包含两层图卷积层(GCNConv)的模型。class GCN(nn.Module):def __init__(self, in_channels, hidden_channels, out_channels):super(GCN, self).__init__()self.conv1 = GCNConv(in_channels, hidden_channels)self.conv2 = GCNConv(hidden_channels, out_channels)def forward(self, data):x, edge_index = data.x, data.edge_indexx = self.conv1(x, edge_index)x = F.relu(x)x = self.conv2(x, edge_index)return F.log_softmax(x, dim=1)5. 初始化模型和优化器model = GCN(in_channels=dataset.num_node_features,hidden_channels=16,out_channels=dataset.num_classes)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)6. 训练模型def train():model.train()optimizer.zero_grad()out = model(data)loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])loss.backward()optimizer.step()return loss.item()for epoch in range(200):loss = train()if epoch % 10 == 0:print(f'Epoch {epoch}, Loss: {loss:.4f}')7. 测试模型def test():model.eval()out = model(data)pred = out.argmax(dim=1)correct = (pred[data.test_mask] == data.y[data.test_mask]).sum()acc = int(correct) / int(data.test_mask.sum())return accaccuracy = test()
print(f'Accuracy: {accuracy:.4f}')8. 完整代码import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv# 加载数据集
dataset = Planetoid(root='/tmp/Cora', name='Cora')
data = dataset[0]# 定义GCN模型
class GCN(nn.Module):def __init__(self, in_channels, hidden_channels, out_channels):super(GCN, self).__init__()self.conv1 = GCNConv(in_channels, hidden_channels)self.conv2 = GCNConv(hidden_channels, out_channels)def forward(self, data):x, edge_index = data.x, data.edge_indexx = self.conv1(x, edge_index)x = F.relu(x)x = self.conv2(x, edge_index)return F.log_softmax(x, dim=1)# 初始化模型和优化器
model = GCN(in_channels=dataset.num_node_features,hidden_channels=16,out_channels=dataset.num_classes)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)# 训练模型
def train():model.train()optimizer.zero_grad()out = model(data)loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])loss.backward()optimizer.step()return loss.item()for epoch in range(200):loss = train()if epoch % 10 == 0:print(f'Epoch {epoch}, Loss: {loss:.4f}')# 测试模型
def test():model.eval()out = model(data)pred = out.argmax(dim=1)correct = (pred[data.test_mask] == data.y[data.test_mask]).sum()acc = int(correct) / int(data.test_mask.sum())return accaccuracy = test()
print(f'Accuracy: {accuracy:.4f}')
'''9. 结果分析通过上述步骤,我们成功地使用PyG构建并训练了一个图卷积神经网络(GCN)模型。
在训练过程中,模型逐步学习图结构数据的特征,最终在测试集上取得了较好的分类准确率。
这展示了PyG在图数据处理和模型构建方面的强大功能。10. 参考文献• PyTorch Geometric官方文档
• PyTorch Geometric教程

通过本教程,您可以了解如何使用PyG进行图神经网络的构建和训练,为进一步的研究和应用奠定基础。


http://www.ppmy.cn/ops/156335.html

相关文章

7.抽象工厂(Abstract Factory)

抽象工厂与工厂方法极其类似,都是绕开new的,但是有些许不同。 动机 在软件系统中,经常面临着“一系列相互依赖的对象”的创建工作;同时,由于需求的变化,往往存在更多系列对象的创建工作。 假设案例 假设…

基于Java、SSM、HTML、Vue在线视频教学网课管理系统设计

摘要 随着互联网技术的飞速发展,在线教育市场呈现出蓬勃的发展态势。本论文聚焦于在线视频教学网课管理系统的设计与实现,该系统基于Java语言,运用SSM(Spring SpringMVC MyBatis)框架构建后端服务,结合H…

Deep Sleep 96小时:一场没有硝烟的科技保卫战

2025年1月28日凌晨3点,当大多数人还沉浸在梦乡时,一场没有硝烟的战争悄然打响。代号“Deep Sleep”的服务器突遭海量数据洪流冲击,警报声响彻机房,一场针对中国关键信息基础设施的网络攻击来势汹汹! 面对美国发起的这场…

VSCode中使用EmmyLua插件对Unity的tolua断点调试

一.VSCode中搜索安装EmmyLua插件 二.创建和编辑launch.json文件 初始的launch.json是这样的 手动编辑加上一段内容如下图所示: 三.启动调试模式,并选择附加的进程

梯度提升用于高效的分类与回归

人工智能例子汇总:AI常见的算法和例子-CSDN博客 使用 决策树(Decision Tree) 实现 梯度提升(Gradient Boosting) 主要是模拟 GBDT(Gradient Boosting Decision Trees) 的原理,即&a…

【Elasticsearch】geohex grid聚合

在 Elasticsearch 中,地理边界过滤是一种用于筛选地理数据的技术,它可以根据指定的地理边界形状(如矩形、多边形等)来过滤符合条件的文档。这种方法在地理空间数据分析中非常有用,尤其是在需要将数据限制在特定地理区域…

C# Monitor类 使用详解

总目录 前言 在 C# 中,Monitor 类是一个用于实现线程同步的重要工具,它提供了一种机制来确保同一时间只有一个线程可以访问特定的代码块或资源,从而避免多线程环境下的数据竞争和不一致问题。下面将对 Monitor 类进行详细介绍。 一、Monitor…

day38|leetcode 322零钱兑换,279.完全平方数,139.单词拆分

322. 零钱兑换 给你一个整数数组 coins ,表示不同面额的硬币;以及一个整数 amount ,表示总金额。 计算并返回可以凑成总金额所需的 最少的硬币个数 。如果没有任何一种硬币组合能组成总金额,返回 -1 。 你可以认为每种硬币的数量是…