RootNeighboursDataset(helpers.dataset_classes文件中的root_neighbours_dataset.py)

news/2024/10/22 19:27:57/

任务类型:回归
用途:在 `RootNeighboursDataset` 中,任务是给定一棵根树,预测根节点度数为6的邻居的特征平均值。因此,模型需要基于根节点的结构,找到度为6的邻居,并计算其特征的平均值。这属于回归问题,因为目标是预测连续值(特征的平均值)

from helpers.dataset_classes.root_neighbours_dataset import RootNeighboursDataset

import torch
from torch_geometric.data import Data, Batch
from typing import Dict, Tuple, List
from torch import Tensorclass RootNeighboursDataset(object):def __init__(self, seed: int, print_flag: bool = False):super().__init__()self.seed = seedself.plot_flag = print_flagself.generator = torch.Generator().manual_seed(seed)self.constants_dict = self.initialize_constants()self._data = self.create_data()def get(self) -> Data:return self._datadef create_data(self) -> Data:# train, val, testdata_list = []for num in range(self.constants_dict['NUM_COMPONENTS']):data_list.append(self.generate_component())return Batch.from_data_list(data_list)def mask_task(self, num_nodes_per_fold: List[int]) -> Tuple[Tensor, Tensor, Tensor]:num_nodes = sum(num_nodes_per_fold)train_mask = torch.zeros(size=(num_nodes,), dtype=torch.bool)val_mask = torch.zeros(size=(num_nodes,), dtype=torch.bool)test_mask = torch.zeros(size=(num_nodes,), dtype=torch.bool)train_mask[0] = Trueval_mask[num_nodes_per_fold[0]] = Truetest_mask[num_nodes_per_fold[0] + num_nodes_per_fold[1]] = Truereturn train_mask, val_mask, test_maskdef generate_component(self) -> Data:data_per_fold, num_nodes_per_fold = [], []for fold_idx in range(3):data = self.generate_fold(eval=(fold_idx != 0))num_nodes_per_fold.append(data.x.shape[0])data_per_fold.append(data)train_mask, val_mask, test_mask = self.mask_task(num_nodes_per_fold=num_nodes_per_fold)batch = Batch.from_data_list(data_per_fold)return Data(x=batch.x, edge_index=batch.edge_index, y=batch.y, train_mask=train_mask, val_mask=val_mask,test_mask=test_mask)def initialize_constants(self) -> Dict[str, int]:return {'NUM_COMPONENTS': 1000, 'MAX_HUBS': 3, 'MAX_1HOP_NEIGHBORS': 10, 'ADD_HUBS': 2, 'HUB_NEIGHBORS': 5,'MAX_2HOP_NEIGHBORS': 3, 'NUM_FEATURES': 5}def generate_fold(self, eval: bool) -> Data:constant_dict = self.initialize_constants()MAX_HUBS, MAX_1HOP_NEIGHBORS, ADD_HUBS, HUB_NEIGHBORS, MAX_2HOP_NEIGHBORS, NUM_FEATURES =\[constant_dict[key] for key in ['MAX_HUBS', 'MAX_1HOP_NEIGHBORS', 'ADD_HUBS', 'HUB_NEIGHBORS','MAX_2HOP_NEIGHBORS', 'NUM_FEATURES']]assert MAX_HUBS + ADD_HUBS <= MAX_1HOP_NEIGHBORSadd_hubs = ADD_HUBS if eval else 0num_hubs = torch.randint(1, MAX_HUBS + 1, size=(1,), generator=self.generator).item() + add_hubsnum_1hop_neighbors = torch.randint(MAX_HUBS + add_hubs, MAX_1HOP_NEIGHBORS + 1, size=(1,),generator=self.generator).item()assert num_hubs <= num_1hop_neighborslist_num_2hop_neighbors = torch.randint(1, MAX_2HOP_NEIGHBORS, size=(num_1hop_neighbors - num_hubs,),generator=self.generator).tolist()list_num_2hop_neighbors = [HUB_NEIGHBORS] * num_hubs + list_num_2hop_neighbors# 2 hop edge indexnum_nodes = 1  # root node is 0idx_1hop_neighbors = []list_edge_index = []for num_2hop_neighbors in list_num_2hop_neighbors:idx_1hop_neighbors.append(num_nodes)if num_2hop_neighbors > 0:clique_edge_index = torch.tensor([[0] * num_2hop_neighbors, list(range(1, num_2hop_neighbors + 1))])# clique_edge_index = torch.combinations(torch.arange(num_2hop_neighbors), r=2).Tlist_edge_index.append(clique_edge_index + num_nodes)num_nodes += num_2hop_neighbors + 1# 1 hop edge indexidx_0hop = torch.tensor([0] * num_1hop_neighbors)idx_1hop_neighbors = torch.tensor(idx_1hop_neighbors)hubs = idx_1hop_neighbors[:num_hubs]list_edge_index.append(torch.stack((idx_0hop, idx_1hop_neighbors), dim=0))edge_index = torch.cat(list_edge_index, dim=1)# undirectedge_index_other_direction = torch.stack((edge_index[1], edge_index[0]), dim=0)edge_index = torch.cat((edge_index_other_direction, edge_index), dim=1)# featuresx = 4 * torch.rand(size=(num_nodes, NUM_FEATURES), generator=self.generator) - 2# labelsy = torch.zeros_like(x)y[0] = torch.mean(x[hubs], dim=0)return Data(x=x, edge_index=edge_index, y=y)if __name__ == '__main__':data = RootNeighboursDataset(seed=0, print_flag=True)

这个 RootNeighboursDataset通过随机生成的树状图数据来模拟一种节点关系,并基于图结构生成特征和标签。代码使用了 PyTorchPyTorch Geometric 的功能来处理图数据。下面逐块详细解释该代码实现:

1. RootNeighboursDataset 类构造器

import torch
from torch_geometric.data import Data, Batch
from typing import Dict, Tuple, List
from torch import Tensorclass RootNeighboursDataset(object):def __init__(self, seed: int, print_flag: bool = False):super().__init__()self.seed &#

http://www.ppmy.cn/news/1541131.html

相关文章

【人工智能】Transformers之Pipeline(二十):令牌分类(token-classification)

目录 一、引言 二、令牌分类&#xff08;token-classification&#xff09; 2.1 概述 2.2 Facebook AI/XLM-RoBERTa 2.3 pipeline参数 2.3.1 pipeline对象实例化参数 2.3.2 pipeline对象使用参数 2.3.3 pipeline返回参数 ​​​​​​​​​​​​​​ 2.4 pipeline…

第7章 网络请求和状态管理

一、Axios 1 Axios概述 Axios是一个基于Promise的HTTP库&#xff0c;可以发送get、post等请求&#xff0c;它作用于浏览器和Node.js中。当运行在浏览器时&#xff0c;使用XMLHttpRequest接口发送请求&#xff1b;当运行在Node.js时&#xff0c;使用HTTP对象发送请求。 Axios的…

使用HIP和OpenMP卸载的Jacobi求解器

Jacobi Solver with HIP and OpenMP offloading — ROCm Blogs (amd.com) 作者&#xff1a;Asitav Mishra, Rajat Arora, Justin Chang 发布日期&#xff1a;2023年9月15日 Jacobi方法作为求解偏微分方程&#xff08;PDE&#xff09;的基本迭代线性求解器在高性能计算&#xff…

每日一练 —— set习题

1. 两个数组的交集 题目链接&#xff1a;349. 两个数组的交集 - 力扣&#xff08;LeetCode&#xff09;https://leetcode.cn/problems/intersection-of-two-arrays/description/ 这题使用set&#xff0c;因为set具有排序和去重的特性 思路&#xff1a; 1.两个值相等就是交集 2.…

【前端】如何制作一个自己的网页(14)

当我们还需要对网页中的内容进行局部样式的修改。这时候&#xff0c;就需要用到HTML中的重要元素&#xff1a;span。 span是一个行内元素&#xff0c;可以对HTML文档中的内容进行局部布局。 如图&#xff0c;我们给标题和段落元素的部分内容设置了各种样式。 接下来&#xff0…

【第一章·为什么要学习编程】

目录 1.1 学习编程的热潮 1.1.1 席卷全球的“编程一小时” 1.1.2 资本汹涌的少儿编程 1.1.3 “再不学编程就晚了” 1.2 为什么要学编程 1.3 什么是“编程” 1.4 怎么学编程 1.4.1 一切都是计算 1.4.2 学编程不是学语法 1.4.3 动手&#xff0c;动手&#xff0c;再动手…

1282:最大子矩阵

题目&#xff1a; 已知矩阵的大小定义为矩阵中所有元素的和。给定一个矩阵&#xff0c;你的任务是找到最大的非空(大小至少是1 1)子矩阵。 比如&#xff0c;如下4 4的矩阵 0 -2 -7 0 9 2 -6 2 -4 1 -4 1 -1 8 0 -2 的最大子矩阵是 9 2 -4 1 -1 8 这个子矩阵的大小是15。 …

代码随想录算法训练营Day08 | 344.反转字符串、541. 反转字符串II、卡码网:54.替换数字

文章目录 344.反转字符串思路与重点 541. 反转字符串II思路与重点 卡码网&#xff1a;54.替换数字思路与重点 344.反转字符串 题目链接&#xff1a;344. 反转字符串 - 力扣&#xff08;LeetCode&#xff09;讲解链接&#xff1a;代码随想录 (programmercarl.com)状态&#xff…