pytorch图神经网络处理图结构数据

news/2025/2/2 2:49:07/

 人工智能例子汇总:AI常见的算法和例子-CSDN博客 

神经网络(Graph Neural Networks,GNNs)是一类能够处理图结构数据的深度学习模型。图结构数据由节点(vertices)和边(edges)组成,其中节点表示实体,边表示实体之间的关系或连接。GNNs 通过在图的结构上进行信息传递和节点嵌入(node embedding)来学习节点或图的特征表示。

GNN的关键思想是通过消息传递机制(message passing)更新每个节点的表示,通常是基于其邻居节点的特征信息。GNNs 可以广泛应用于许多领域,如社交网络分析、推荐系统、知识图谱、分子图表示等。

以下是GNN的基本组成部分和工作原理:

  1. 节点表示更新:每个节点的表示通过其邻居节点的表示进行更新。常见的做法是通过聚合邻居节点的特征,然后与节点本身的特征进行结合

GNN的变种

  1. GCN(Graph Convolutional Networks):一种基于图卷积的GNN,通过聚合邻居节点的特征来更新节点表示,适用于无向图。

  2. GraphSAGE(Graph Sample and Aggregation):通过随机采样邻居节点来提高计算效率,尤其适用于大规模图。

  3. GAT(Graph Attention Networks):引入了注意力机制,使得不同邻居对节点更新的贡献不同,能够动态调整每个邻居的权重。

  4. Graph Isomorphism Network (GIN):通过强大的表征能力增强了图的判别性。

GNN的应用

  • 社交网络分析:预测用户之间的关系或用户的兴趣。
  • 推荐系统:基于用户和物品之间的图结构进行个性化推荐。
  • 生物信息学:如分子图表示,用于药物发现、蛋白质结构预测等。
  • 图像分割与语义分析:在视觉任务中处理图形数据,捕捉图像之间的关系。

例子:

import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.data import Data
import matplotlib.pyplot as plt# 1. 生成随机图数据
num_nodes = 100
x = torch.rand((num_nodes, 2))  # 100 个节点,每个节点有 2 维特征
y = (x[:, 0] + x[:, 1] > 1).long()  # 二分类标签(0 或 1)# 2. 生成图结构(邻接关系)
edge_index = []
for i in range(num_nodes):for j in range(i + 1, num_nodes):if (y[i] == y[j] and torch.rand(1).item() > 0.6) or (y[i] != y[j] and torch.rand(1).item() > 0.9):edge_index.append([i, j])edge_index.append([j, i])
edge_index = torch.tensor(edge_index, dtype=torch.long).t()# 3. 训练集和测试集
train_mask = torch.rand(num_nodes) < 0.8  # 80% 训练,20% 测试
test_mask = ~train_mask# 4. 构造 PyG 数据对象
data = Data(x=x, edge_index=edge_index, y=y, train_mask=train_mask, test_mask=test_mask)# 5. 定义 4 层 GCN 模型
class GCN(torch.nn.Module):def __init__(self):super(GCN, self).__init__()self.conv1 = GCNConv(2, 16)self.conv2 = GCNConv(16, 16)self.conv3 = GCNConv(16, 16)  # 将 conv3 输出改为与输入维度相同self.conv4 = GCNConv(16, 2)  # 输出类别数 2def forward(self, data):x, edge_index = data.x, data.edge_indexx = F.relu(self.conv1(x, edge_index))x = F.relu(self.conv2(x, edge_index))x = F.relu(self.conv3(x, edge_index)) + x  # 跳跃连接,维度一致x = self.conv4(x, edge_index)return F.log_softmax(x, dim=1)  # 输出对数概率# 6. 训练模型
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = GCN().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=500, gamma=0.5)  # 学习率衰减data = data.to(device)
num_epochs = 2000  # 增加训练轮数for epoch in range(num_epochs):model.train()optimizer.zero_grad()out = model(data)loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])loss.backward()optimizer.step()scheduler.step()  # 逐步降低学习率if epoch % 200 == 0:print(f"Epoch {epoch}, Loss: {loss.item():.4f}")# 7. 评估模型
model.eval()
out = model(data)
pred = out.argmax(dim=1)  # 取最大值的索引作为类别
test_pred = pred[data.test_mask]
test_true = data.y[data.test_mask]# 8. 过滤低置信度预测
proba = torch.exp(out)  # 转换为 softmax
test_pred[proba[data.test_mask].max(dim=1)[0] < 0.6] = -1  # 低置信度设为 -1# 9. 可视化测试结果
test_mask_np = torch.arange(num_nodes)[data.test_mask].cpu().numpy()
test_pred_np = test_pred.cpu().numpy()
test_true_np = test_true.cpu().numpy()plt.figure(figsize=(10, 5))
plt.scatter(test_mask_np, test_pred_np, color='blue', alpha=0.5, label='Predicted')
plt.scatter(test_mask_np, test_true_np, color='red', alpha=0.5, label='True')
plt.xlabel('Test Node Index')
plt.ylabel('Node Class')
plt.title('Test Results vs True Results')
plt.legend()
plt.show()


http://www.ppmy.cn/news/1568579.html

相关文章

29. C语言 可变参数详解

本章目录: 前言可变参数的基本概念可变参数的工作原理如何使用可变参数 示例&#xff1a;计算多个整数的平均值解析&#xff1a; 更复杂的可变参数示例&#xff1a;打印可变数量的字符串解析&#xff1a; 总结 前言 在C语言中&#xff0c;函数参数的数量通常是固定的&#xff…

题海拾贝:力扣 622.设计循环队列

Hello大家好&#xff01;很高兴我们又见面啦&#xff01;给生活添点passion&#xff0c;开始今天的编程之路&#xff01; 我的博客&#xff1a;<但凡. 我的专栏&#xff1a;《编程之路》、《数据结构与算法之美》、《题海拾贝》 欢迎点赞&#xff0c;关注&#xff01; 1、题…

RAG:实现基于本地知识库结合大模型生成(LangChain4j快速入门#1)

引言 ⭐Tips&#xff1a; 你可以循序渐进从头看下去也可以选择直接跳到后面(快速入门)看代码和结果演示 场景解释以及适用场景 当我想让大模型能基于我私有化的一些本地知识进行回答&#xff0c;定制化特殊场景模型的时候&#xff0c;就可以用到这种方法。 示例1&#xff1a;…

在5G网络中使用IEEE 1588实现保持时间同步

本文主要探讨了在电信网络中实现保持时间同步&#xff08;holdover&#xff09;的不同方法。 文档讨论了保持时间同步的作用&#xff0c;以及它从传统SONET/SDH网络到现代5G移动通信网络的演变。传统SONET/SDH网络依赖于频率同步&#xff0c;而现代5G移动通信则依赖于使用IEEE…

单片机基础模块学习——DS18B20温度传感器芯片

不知道该往哪走的时候&#xff0c;就往前走。 一、DS18B20芯片原理图 该芯片共有三个引脚&#xff0c;分别为 GND——接地引脚DQ——数据通信引脚VDD——正电源 数据通信用到的是1-Wier协议 优点&#xff1a;占用端口少&#xff0c;电路设计方便 同时该协议要求通过上拉电阻…

第P7周-Pytorch实现马铃薯病害识别(VGG16复现)

&#x1f368; 本文为&#x1f517;365天深度学习训练营 中的学习记录博客&#x1f356; 原作者&#xff1a;K同学啊 目标 马铃薯病害数据集&#xff0c;该数据集包含表现出各种疾病的马铃薯植物的高分辨率图像&#xff0c;包括早期疫病、晚期疫病和健康叶子。它旨在帮助开发和…

SAP SD学习笔记27 - 请求计划(开票计划)之1 - 定期请求(定期开票)

上两章讲了贩卖契约&#xff08;框架协议&#xff09;的概要&#xff0c;以及贩卖契约中最为常用的 基本契约 - 数量契约和金额契约。 SAP SD学习笔记26 - 贩卖契约(框架协议)的概要&#xff0c;基本契约 - 数量契约_sap 框架协议-CSDN博客 SAP SD学习笔记27 - 贩卖契约(框架…

什么是波士顿矩阵,怎么制作?AI工具一键生成战略分析图!

当今商业环境瞬息万变&#xff0c;每个企业都面临着越来越多的挑战与机遇。如何科学合理地进行战略管理&#xff0c;成为了每个企业决策者必须直面的重要课题。 在众多战略管理框架中&#xff0c;波士顿矩阵作为一种经典的战略管理工具&#xff0c;因其简洁明了的分析方式而广…