pytorch图神经网络处理图结构数据

embedded/2025/2/5 21:11:37/

 人工智能例子汇总:AI常见的算法和例子-CSDN博客 

神经网络(Graph Neural Networks,GNNs)是一类能够处理图结构数据的深度学习模型。图结构数据由节点(vertices)和边(edges)组成,其中节点表示实体,边表示实体之间的关系或连接。GNNs 通过在图的结构上进行信息传递和节点嵌入(node embedding)来学习节点或图的特征表示。

GNN的关键思想是通过消息传递机制(message passing)更新每个节点的表示,通常是基于其邻居节点的特征信息。GNNs 可以广泛应用于许多领域,如社交网络分析、推荐系统、知识图谱、分子图表示等。

以下是GNN的基本组成部分和工作原理:

  1. 节点表示更新:每个节点的表示通过其邻居节点的表示进行更新。常见的做法是通过聚合邻居节点的特征,然后与节点本身的特征进行结合

GNN的变种

  1. GCN(Graph Convolutional Networks):一种基于图卷积的GNN,通过聚合邻居节点的特征来更新节点表示,适用于无向图。

  2. GraphSAGE(Graph Sample and Aggregation):通过随机采样邻居节点来提高计算效率,尤其适用于大规模图。

  3. GAT(Graph Attention Networks):引入了注意力机制,使得不同邻居对节点更新的贡献不同,能够动态调整每个邻居的权重。

  4. Graph Isomorphism Network (GIN):通过强大的表征能力增强了图的判别性。

GNN的应用

  • 社交网络分析:预测用户之间的关系或用户的兴趣。
  • 推荐系统:基于用户和物品之间的图结构进行个性化推荐。
  • 生物信息学:如分子图表示,用于药物发现、蛋白质结构预测等。
  • 图像分割与语义分析:在视觉任务中处理图形数据,捕捉图像之间的关系。

例子:

import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.data import Data
import matplotlib.pyplot as plt# 1. 生成随机图数据
num_nodes = 100
x = torch.rand((num_nodes, 2))  # 100 个节点,每个节点有 2 维特征
y = (x[:, 0] + x[:, 1] > 1).long()  # 二分类标签(0 或 1)# 2. 生成图结构(邻接关系)
edge_index = []
for i in range(num_nodes):for j in range(i + 1, num_nodes):if (y[i] == y[j] and torch.rand(1).item() > 0.6) or (y[i] != y[j] and torch.rand(1).item() > 0.9):edge_index.append([i, j])edge_index.append([j, i])
edge_index = torch.tensor(edge_index, dtype=torch.long).t()# 3. 训练集和测试集
train_mask = torch.rand(num_nodes) < 0.8  # 80% 训练,20% 测试
test_mask = ~train_mask# 4. 构造 PyG 数据对象
data = Data(x=x, edge_index=edge_index, y=y, train_mask=train_mask, test_mask=test_mask)# 5. 定义 4 层 GCN 模型
class GCN(torch.nn.Module):def __init__(self):super(GCN, self).__init__()self.conv1 = GCNConv(2, 16)self.conv2 = GCNConv(16, 16)self.conv3 = GCNConv(16, 16)  # 将 conv3 输出改为与输入维度相同self.conv4 = GCNConv(16, 2)  # 输出类别数 2def forward(self, data):x, edge_index = data.x, data.edge_indexx = F.relu(self.conv1(x, edge_index))x = F.relu(self.conv2(x, edge_index))x = F.relu(self.conv3(x, edge_index)) + x  # 跳跃连接,维度一致x = self.conv4(x, edge_index)return F.log_softmax(x, dim=1)  # 输出对数概率# 6. 训练模型
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = GCN().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=500, gamma=0.5)  # 学习率衰减data = data.to(device)
num_epochs = 2000  # 增加训练轮数for epoch in range(num_epochs):model.train()optimizer.zero_grad()out = model(data)loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])loss.backward()optimizer.step()scheduler.step()  # 逐步降低学习率if epoch % 200 == 0:print(f"Epoch {epoch}, Loss: {loss.item():.4f}")# 7. 评估模型
model.eval()
out = model(data)
pred = out.argmax(dim=1)  # 取最大值的索引作为类别
test_pred = pred[data.test_mask]
test_true = data.y[data.test_mask]# 8. 过滤低置信度预测
proba = torch.exp(out)  # 转换为 softmax
test_pred[proba[data.test_mask].max(dim=1)[0] < 0.6] = -1  # 低置信度设为 -1# 9. 可视化测试结果
test_mask_np = torch.arange(num_nodes)[data.test_mask].cpu().numpy()
test_pred_np = test_pred.cpu().numpy()
test_true_np = test_true.cpu().numpy()plt.figure(figsize=(10, 5))
plt.scatter(test_mask_np, test_pred_np, color='blue', alpha=0.5, label='Predicted')
plt.scatter(test_mask_np, test_true_np, color='red', alpha=0.5, label='True')
plt.xlabel('Test Node Index')
plt.ylabel('Node Class')
plt.title('Test Results vs True Results')
plt.legend()
plt.show()


http://www.ppmy.cn/embedded/159849.html

相关文章

Python从0到100(八十六):神经网络-ShuffleNet通道混合轻量级网络的深入介绍

前言&#xff1a; 零基础学Python&#xff1a;Python从0到100最新最全教程。 想做这件事情很久了&#xff0c;这次我更新了自己所写过的所有博客&#xff0c;汇集成了Python从0到100&#xff0c;共一百节课&#xff0c;帮助大家一个月时间里从零基础到学习Python基础语法、Pyth…

小程序越来越智能化,作为设计师要如何进行创新设计

一、用户体验至上 &#xff08;一&#xff09;简洁高效的界面设计 小程序的特点之一是轻便快捷&#xff0c;用户期望能够在最短的时间内找到所需功能并完成操作。因此&#xff0c;设计师应致力于打造简洁高效的界面。避免过多的装饰元素和复杂的布局&#xff0c;采用清晰的导航…

深度学习 Pytorch 深层神经网络

在之前已经学习了三种单层神经网络&#xff0c;分别为实现线性方程的回归网络&#xff0c;实现二分类的逻辑回归&#xff08;二分类网络&#xff09;&#xff0c;以及实现多分类的softmax回归&#xff08;多分类网络&#xff09;。从本节开始&#xff0c;我们将从单层神经网络展…

网络安全攻防实战:从基础防护到高级对抗

&#x1f4dd;个人主页&#x1f339;&#xff1a;一ge科研小菜鸡-CSDN博客 &#x1f339;&#x1f339;期待您的关注 &#x1f339;&#x1f339; 引言 在信息化时代&#xff0c;网络安全已经成为企业、政府和个人必须重视的问题。从数据泄露到勒索软件攻击&#xff0c;每一次…

openssl 静态编译

1. 下载 openssl 各版本下载 https://openssl-library.org/source/old/index.html 2. 静态编译 ./config -fPIC no-shared make -j 4编译后的静态文件在 [chenlocalhost openssl-1.1.1g]$ ls | grep \.a libcrypto.a libssl.a编译后的执行文件在 ./apps/openssl# eg: [chen…

FPGA| 使用Quartus II报错Top-level design entity ““ is undefined

1、使用FPGA准备点亮LED测试下板子&#xff0c;发现这个报错Error (12007): Top-level design entity "LEDLED" is undefined 工程如上图 报错如下图 2、分析到原因是因为工程名称和顶层模块里面的module名称不一样导致 解决办法&#xff1a;修改module名称和顶层模…

Nginx的路径匹配规则 笔记250203

Nginx的路径匹配规则 Nginx 的路径匹配规则主要通过 location 指令实现&#xff0c;用于根据请求的 URI&#xff08;路径&#xff09;将请求路由到不同的处理逻辑。其匹配规则灵活且功能强大&#xff0c;但也需要谨慎配置以避免冲突。以下是 Nginx 路径匹配规则的详细解析&…

图形学笔记 - 5-光线追踪 - 辐射度量学

文章目录 辐射度量学辐射能和通量&#xff08;功率&#xff09;Radiant Energy and Flux (Power)辐射强度 Radiant Intensity辐照度Irradiance朗伯余弦定律Lambert’s Cosine Law Radiance辐亮度Incident Radiance入射辐亮度Exiting Radiance出射辐亮度 双向反射分布函数 Bidir…