PyTorch Geometric(PyG)机器学习实战

embedded/2025/2/9 4:49:31/

PyTorch Geometric(PyG)机器学习实战

在图神经网络(GNN)的研究和应用中,PyTorch Geometric(PyG)作为一个基于PyTorch的库,提供了高效的图数据处理和模型构建功能。
本文将通过一个节点分类任务,演示如何使用PyG进行机器学习实战。

1. 环境准备

首先,确保已安装PyTorch和PyG。可以使用以下命令进行安装:

pip install torch
pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric2. 导入必要的库import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv3. 加载数据集我们使用PyG自带的Planetoid数据集,这里以Cora数据集为例。dataset = Planetoid(root='/tmp/Cora', name='Cora')
data = dataset[0]4. 定义GCN模型我们将构建一个包含两层图卷积层(GCNConv)的模型。class GCN(nn.Module):def __init__(self, in_channels, hidden_channels, out_channels):super(GCN, self).__init__()self.conv1 = GCNConv(in_channels, hidden_channels)self.conv2 = GCNConv(hidden_channels, out_channels)def forward(self, data):x, edge_index = data.x, data.edge_indexx = self.conv1(x, edge_index)x = F.relu(x)x = self.conv2(x, edge_index)return F.log_softmax(x, dim=1)5. 初始化模型和优化器model = GCN(in_channels=dataset.num_node_features,hidden_channels=16,out_channels=dataset.num_classes)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)6. 训练模型def train():model.train()optimizer.zero_grad()out = model(data)loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])loss.backward()optimizer.step()return loss.item()for epoch in range(200):loss = train()if epoch % 10 == 0:print(f'Epoch {epoch}, Loss: {loss:.4f}')7. 测试模型def test():model.eval()out = model(data)pred = out.argmax(dim=1)correct = (pred[data.test_mask] == data.y[data.test_mask]).sum()acc = int(correct) / int(data.test_mask.sum())return accaccuracy = test()
print(f'Accuracy: {accuracy:.4f}')8. 完整代码import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv# 加载数据集
dataset = Planetoid(root='/tmp/Cora', name='Cora')
data = dataset[0]# 定义GCN模型
class GCN(nn.Module):def __init__(self, in_channels, hidden_channels, out_channels):super(GCN, self).__init__()self.conv1 = GCNConv(in_channels, hidden_channels)self.conv2 = GCNConv(hidden_channels, out_channels)def forward(self, data):x, edge_index = data.x, data.edge_indexx = self.conv1(x, edge_index)x = F.relu(x)x = self.conv2(x, edge_index)return F.log_softmax(x, dim=1)# 初始化模型和优化器
model = GCN(in_channels=dataset.num_node_features,hidden_channels=16,out_channels=dataset.num_classes)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)# 训练模型
def train():model.train()optimizer.zero_grad()out = model(data)loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])loss.backward()optimizer.step()return loss.item()for epoch in range(200):loss = train()if epoch % 10 == 0:print(f'Epoch {epoch}, Loss: {loss:.4f}')# 测试模型
def test():model.eval()out = model(data)pred = out.argmax(dim=1)correct = (pred[data.test_mask] == data.y[data.test_mask]).sum()acc = int(correct) / int(data.test_mask.sum())return accaccuracy = test()
print(f'Accuracy: {accuracy:.4f}')
'''9. 结果分析通过上述步骤,我们成功地使用PyG构建并训练了一个图卷积神经网络(GCN)模型。
在训练过程中,模型逐步学习图结构数据的特征,最终在测试集上取得了较好的分类准确率。
这展示了PyG在图数据处理和模型构建方面的强大功能。10. 参考文献• PyTorch Geometric官方文档
• PyTorch Geometric教程

通过本教程,您可以了解如何使用PyG进行图神经网络的构建和训练,为进一步的研究和应用奠定基础。


http://www.ppmy.cn/embedded/160713.html

相关文章

【R语言】写入数据

一、写入R语言系统格式的数据 R语言自带.RData和.rds两种数据格式。 通过使用save()函数和saveRDS()函数将R语言数据处理结果保存为此类数据。 # 将iris数据集保存为RData文件 save(listc("iris"), file"iris.RData") # 将iris数据集保存为rds文件 save…

python-leetcode-被围绕的区域

130. 被围绕的区域 - 力扣(LeetCode) class Solution:def solve(self, board: List[List[str]]) -> None:"""Do not return anything, modify board in-place instead."""if not board or not board[0]:returnrows, co…

【自学笔记】Python的基础知识点总览-持续更新

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 Python基础知识总览1. Python简介2. 安装与环境配置3. 基本语法3.1 变量与数据类型3.2 控制结构3.3 函数与模块3.4 文件操作 4. 面向对象编程(OOP&#…

力扣 无重复字符的最长子串

滑动窗口,双指针移动找集合类的元素。 题目 无重复,可想到hashset集,然后由题找最长子串,说明要处理左右边界,可以用双指针,右指针一直遍历,左指针看到重复就加一,这像是一个滑动窗…

【算法专场】分治(下)

目录 前言 归并排序 思想 912. 排序数组 算法思路 算法代码 LCR 170. 交易逆序对的总数 算法思路 算法代码 315. 计算右侧小于当前元素的个数 - 力扣(LeetCode) 算法思路 算法代码 493. 翻转对 算法思路 算法代码 好久不见~时隔多日&…

限流策略实战指南:从算法选择到阈值设置,打造高可用系统

前言 本文将深入探讨常见的限流算法及其适用场景,并详细解析基于 QPS 的限流方案。从如何设置合理的限流阈值,到请求被限流后的处理策略。 常见的限流算法 漏桶 核心原理 请求以任意速率进桶,以 恒定速率 出桶。若桶满则丢弃或排队等待适…

【3分钟极速部署】在本地快速部署deepseek

第一步,找到网站,下载: 首先找到Ollama , 根据自己的电脑下载对应的版本 。 我个人用的是Windows 我就先尝试用Windows版本了 ,文件不是很大,下载也比较的快 第二部就是安装了 : 安装完成后提示…

将Deepseek接入pycharm 进行AI编程

目录 专栏导读1、进入Deepseek开放平台创建 API key 2、调用 API代码 3、成功4、补充说明多轮对话 总结 专栏导读 🌸 欢迎来到Python办公自动化专栏—Python处理办公问题,解放您的双手 🏳️‍🌈 博客主页:请点击——…