pytorch图神经网络处理图结构数据

devtools/2025/2/7 0:20:10/

 人工智能例子汇总:AI常见的算法和例子-CSDN博客 

神经网络(Graph Neural Networks,GNNs)是一类能够处理图结构数据的深度学习模型。图结构数据由节点(vertices)和边(edges)组成,其中节点表示实体,边表示实体之间的关系或连接。GNNs 通过在图的结构上进行信息传递和节点嵌入(node embedding)来学习节点或图的特征表示。

GNN的关键思想是通过消息传递机制(message passing)更新每个节点的表示,通常是基于其邻居节点的特征信息。GNNs 可以广泛应用于许多领域,如社交网络分析、推荐系统、知识图谱、分子图表示等。

以下是GNN的基本组成部分和工作原理:

  1. 节点表示更新:每个节点的表示通过其邻居节点的表示进行更新。常见的做法是通过聚合邻居节点的特征,然后与节点本身的特征进行结合

GNN的变种

  1. GCN(Graph Convolutional Networks):一种基于图卷积的GNN,通过聚合邻居节点的特征来更新节点表示,适用于无向图。

  2. GraphSAGE(Graph Sample and Aggregation):通过随机采样邻居节点来提高计算效率,尤其适用于大规模图。

  3. GAT(Graph Attention Networks):引入了注意力机制,使得不同邻居对节点更新的贡献不同,能够动态调整每个邻居的权重。

  4. Graph Isomorphism Network (GIN):通过强大的表征能力增强了图的判别性。

GNN的应用

  • 社交网络分析:预测用户之间的关系或用户的兴趣。
  • 推荐系统:基于用户和物品之间的图结构进行个性化推荐。
  • 生物信息学:如分子图表示,用于药物发现、蛋白质结构预测等。
  • 图像分割与语义分析:在视觉任务中处理图形数据,捕捉图像之间的关系。

例子:

import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.data import Data
import matplotlib.pyplot as plt# 1. 生成随机图数据
num_nodes = 100
x = torch.rand((num_nodes, 2))  # 100 个节点,每个节点有 2 维特征
y = (x[:, 0] + x[:, 1] > 1).long()  # 二分类标签(0 或 1)# 2. 生成图结构(邻接关系)
edge_index = []
for i in range(num_nodes):for j in range(i + 1, num_nodes):if (y[i] == y[j] and torch.rand(1).item() > 0.6) or (y[i] != y[j] and torch.rand(1).item() > 0.9):edge_index.append([i, j])edge_index.append([j, i])
edge_index = torch.tensor(edge_index, dtype=torch.long).t()# 3. 训练集和测试集
train_mask = torch.rand(num_nodes) < 0.8  # 80% 训练,20% 测试
test_mask = ~train_mask# 4. 构造 PyG 数据对象
data = Data(x=x, edge_index=edge_index, y=y, train_mask=train_mask, test_mask=test_mask)# 5. 定义 4 层 GCN 模型
class GCN(torch.nn.Module):def __init__(self):super(GCN, self).__init__()self.conv1 = GCNConv(2, 16)self.conv2 = GCNConv(16, 16)self.conv3 = GCNConv(16, 16)  # 将 conv3 输出改为与输入维度相同self.conv4 = GCNConv(16, 2)  # 输出类别数 2def forward(self, data):x, edge_index = data.x, data.edge_indexx = F.relu(self.conv1(x, edge_index))x = F.relu(self.conv2(x, edge_index))x = F.relu(self.conv3(x, edge_index)) + x  # 跳跃连接,维度一致x = self.conv4(x, edge_index)return F.log_softmax(x, dim=1)  # 输出对数概率# 6. 训练模型
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = GCN().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=500, gamma=0.5)  # 学习率衰减data = data.to(device)
num_epochs = 2000  # 增加训练轮数for epoch in range(num_epochs):model.train()optimizer.zero_grad()out = model(data)loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])loss.backward()optimizer.step()scheduler.step()  # 逐步降低学习率if epoch % 200 == 0:print(f"Epoch {epoch}, Loss: {loss.item():.4f}")# 7. 评估模型
model.eval()
out = model(data)
pred = out.argmax(dim=1)  # 取最大值的索引作为类别
test_pred = pred[data.test_mask]
test_true = data.y[data.test_mask]# 8. 过滤低置信度预测
proba = torch.exp(out)  # 转换为 softmax
test_pred[proba[data.test_mask].max(dim=1)[0] < 0.6] = -1  # 低置信度设为 -1# 9. 可视化测试结果
test_mask_np = torch.arange(num_nodes)[data.test_mask].cpu().numpy()
test_pred_np = test_pred.cpu().numpy()
test_true_np = test_true.cpu().numpy()plt.figure(figsize=(10, 5))
plt.scatter(test_mask_np, test_pred_np, color='blue', alpha=0.5, label='Predicted')
plt.scatter(test_mask_np, test_true_np, color='red', alpha=0.5, label='True')
plt.xlabel('Test Node Index')
plt.ylabel('Node Class')
plt.title('Test Results vs True Results')
plt.legend()
plt.show()


http://www.ppmy.cn/devtools/156655.html

相关文章

TCP连接管理与UDP协议IP协议与ethernet协议

SEO Meta Description: 深入解析TCP连接管理、UDP协议、IP协议与Ethernet协议的工作原理及其在网络通信中的应用&#xff0c;全面了解各协议的功能与区别。 介绍 网络通信依赖于一系列协议来确保数据的可靠传输和高效处理。本文将详细介绍TCP连接管理、UDP协议、IP协议和Ethe…

深入解析:如何获取商品 SKU 详细信息

在电商领域&#xff0c;SKU&#xff08;Stock Keeping Unit&#xff0c;库存进出计量的基本单元&#xff09;是商品管理中的一个重要概念。每个 SKU 都代表了一个具体的产品变体&#xff0c;例如不同的颜色、尺寸或配置。获取商品的 SKU 详细信息对于商家优化库存管理、提高运营…

ChatGPT怎么回事?

纯属发现&#xff0c;调侃一下~ 这段时间deepseek不是特别火吗&#xff0c;尤其是它的推理功能&#xff0c;突发奇想&#xff0c;想用deepseek回答一些问题&#xff0c;回答一个问题之后就回复服务器繁忙&#xff08;估计还在被攻击吧~_~&#xff09; 然后就转向了GPT&#xf…

【自学笔记】Git的重点知识点-持续更新

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 Git基础知识Git高级操作与概念Git常用命令 总结 Git基础知识 Git简介 Git是一种分布式版本控制系统&#xff0c;用于记录文件内容的改动&#xff0c;便于开发者追踪…

优选算法合集————双指针(专题二)

好久都没给大家带来算法专题啦&#xff0c;今天给大家带来滑动窗口专题的训练 题目一&#xff1a;长度最小的子数组 题目描述&#xff1a; 给定一个含有 n 个正整数的数组和一个正整数 target 。 找出该数组中满足其和 ≥ target 的长度最小的 连续子数组 [numsl, numsl1, …

洛谷P2638 安全系统

安全系统 题目描述 特斯拉公司的六位密码被轻松破解后&#xff0c;引发了人们对电动车的安全性能的怀疑。李华听闻后&#xff0c;自己设计了一套密码&#xff1a; 假设安全系统中有 n n n 个储存区&#xff0c;每个储存区最多能存储存 2 2 2 种种类不同的信号&#xff08;…

实现一个 LRU 风格的缓存类

实现一个缓存类 需求描述豆包解决思路&#xff1a;实现代码&#xff1a;优化11. std::list::remove 的时间复杂度问题2. 代码复用优化后的代码优化说明 优化21. 边界条件检查2. 异常处理3. 代码封装性4. 线程安全优化后的代码示例优化说明 DeepSeek&#xff08;深度思考R1&…

攻防世界 php2

你能验证这个网站吗&#xff1f; 根据提示是PHP&#xff0c;我们知道PHP代码在服务器端执行,可以连接数据库&#xff0c;查询数据&#xff0c;并根据查询结果动态生成网页内容。 PHP代码通常是保存在以.PHP为扩展名的文件上,这些文件存储在web 服务器的文档根目录中&#xff…