6.如何用CSV文件生成异构图数据集

news/2024/11/15 0:51:04/

       我们将使用GroupLens研究小组收集的MovieLens数据集
       这个数据集描述了MovieLens的五星评级和标记活动。该数据集包含来自600多名用户的9000多部电影的约10万个评分。我们将使用该数据集生成两种节点类型,分别保存电影和用户的数据,以及一种连接用户和电影的边类型,表示用户对特定电影的评分关系。
       首先,我们将数据集下载到任意文件夹(在本例中为当前目录):

from torch_geometric.data import download_url, extract_zipurl = 'https://files.grouplens.org/datasets/movielens/ml-latest-small.zip'
extract_zip(download_url(url, '.'), '.')movies_path = './ml-latest-small/movies.csv'
ratings_path = './ml-latest-small/ratings.csv'

打开数据集,就可以看到以下一些文件:
在这里插入图片描述

import pandas as pdprint(pd.read_csv(movies_path).head()) # DataFrame对象,默认显示前5行
print(pd.read_csv(ratings_path).head())

在这里插入图片描述
       为了用PyG数据格式表示这些数据,我们首先定义了一个方法load_node_csv(),该方法读取*.csv文件并返回形状为[num_nodes,num_features]的节点级特征表示x:

import torchdef load_node_csv(path, index_col, encoders=None, **kwargs): # **kwargs用于在函数定义中接收任意数量的关键字参数,是一个字典df = pd.read_csv(path, index_col=index_col, **kwargs) # 读取*.csvmapping = {index: i for i, index in enumerate(df.index.unique())} # 将索引映射成连续值x = Noneif encoders is not None:xs = [encoder(df[col]) for col, encoder in encoders.items()]x = torch.cat(xs, dim=-1)return x, mapping

在这里插入图片描述
       这里,load_node_csv()从路径读取*.csv文件,并创建一个字典映射,将其索引列映射到范围{0,…,num_rows-1}中的连续值。这是必要的,因为我们希望我们的最终数据表示尽可能紧凑,例如,第一行中的电影表示应该可以通过x[0]访问。

from sentence_transformers import SentenceTransformer
class SequenceEncoder:def __init__(self, model_name='all-MiniLM-L6-v2', device=None):self.device = deviceself.model = SentenceTransformer(model_name, device=device)@torch.no_grad()def __call__(self, df):x = self.model.encode(df.values, show_progress_bar=True,convert_to_tensor=True, device=self.device)print(x.shape)return x.cpu()

       SequenceEncoder类加载一个由model_name给定的预先训练的NLP模型,并使用它将字符串列表编码为形状为[num_strings,embedding_dim]的PyTorch张量。我们可以使用此SequenceEncodermovies.csv文件的标题进行编码。

       以类似的方式,我们可以创建另一个编码器,将电影类型转换为分类标签。为此,我们首先需要找到数据中存在的所有电影类型,创建shape[num_movies,num_genres]的特征表示x,并在类型j存在于电影i中的情况下将1分配给x[i,j]:

class GenresEncoder:def __init__(self, sep='|'):self.sep = sepdef __call__(self, df):genres = set(g for col in df.values for g in col.split(self.sep))mapping = {genre: i for i, genre in enumerate(genres)}x = torch.zeros(len(df), len(mapping))for i, col in enumerate(df.values):for genre in col.split(self.sep):x[i, mapping[genre]] = 1print(x.shape)return x

       有了这个,我们可以通过以下方式获得我们对电影的最终呈现:
在这里插入图片描述
       类似地,我们也可以使用load_node_csv()来获得从userId到连续值的用户映射。但是,此数据集中没有用户的其他特征信息。因此,我们没有定义任何编码器:
在这里插入图片描述

       这样,我们就可以初始化HeteroData对象,并将两种节点类型传递给它:

from torch_geometric.data import HeteroDatadata = HeteroData()data['user'].num_nodes = len(user_mapping)  # Users do not have any features.
data['movie'].x = movie_xprint(data)
print(movie_x.shape)

在这里插入图片描述
       由于用户没有任何节点级别的信息,我们只定义其节点数。因此,在异构图模型的训练过程中,我们可能需要通过torch.nn.Embedding以端到端的方式学习不同的用户嵌入。

       接下来,我们来看看根据用户的评分将他们与电影联系起来。为此,我们定义了一个方法load_edge_csv(),该方法从ratings.csv返回shape[2,num_ratings]的最终edge_index表示,以及原始*.csv文件中存在的任何其他功能:

def load_edge_csv(path, src_index_col, src_mapping, dst_index_col, dst_mapping,encoders=None, **kwargs):df = pd.read_csv(path, **kwargs)src = [src_mapping[index] for index in df[src_index_col]]dst = [dst_mapping[index] for index in df[dst_index_col]]#print(len(src))#print(len(dst))edge_index = torch.tensor([src, dst])edge_attr = Noneif encoders is not None:edge_attrs = [encoder(df[col]) for col, encoder in encoders.items()]edge_attr = torch.cat(edge_attrs, dim=-1)#print(edge_attr.shape)return edge_index, edge_attr

       这里,src_index_coldst_index_col分别定义源节点和目标节点的索引列。我们进一步利用节点级映射src_mappingdst_mapping来确保原始索引在我们的最终表示中被映射到正确的连续索引。

       对于文件中定义的每条边,它会在src_mappingdst_mapping中查找正向索引,并适当地移动数据。

       load_node_csv()类似,编码器用于返回额外的边特征信息。例如,为了从ratings.csv中的rating列加载ratings,我们可以定义一个IdentityEncoder,它只需将浮点值列表转换为PyTorch张量:

class IdentityEncoder:def __init__(self, dtype=None):self.dtype = dtypedef __call__(self, df):return torch.from_numpy(df.values).view(-1, 1).to(self.dtype)

       这样,我们就可以完成我们的HeteroData对象了:

edge_index, edge_label = load_edge_csv(ratings_path,src_index_col='userId',src_mapping=user_mapping,dst_index_col='movieId',dst_mapping=movie_mapping,encoders={'rating': IdentityEncoder(dtype=torch.long)},
)data['user', 'rates', 'movie'].edge_index = edge_index
data['user', 'rates', 'movie'].edge_label = edge_labelprint(data)

在这里插入图片描述
       该HeteroData对象是PyG中异构图的原生格式,可以用作异构图模型的输入。

本文内容参考:PyG官网
视频讲解:4.如何用CSV文件生成异构图数据集


http://www.ppmy.cn/news/1004436.html

相关文章

Shell脚本学习-Shell脚本框架

我们能写出.sh文件的脚本。已经觉得很好了。但是我们还需要进一步学习脚本框架的概念。 1、Shell脚本(模块)高级命名规则: 1)常规Shell脚本:chang.sh、test.sh等 2)模块的启动和停止统一命名为&#xff1…

Qt5.13引入QtWebApp的模块后报错: error C2440: “reinterpret_cast”: 无法从“int”转换为“quintptr”

1、开发环境 Win10-64 qt5.13 msvc2015-64bit-release 2、报错 新建一个demo工程。 引入QtWebApp的httpserver、logging、templateengine三个模块后。 直接运行,,此时报错如下: E:\Qt5.13.1\install\5.13.1\msvc2015_64\include\QtCore…

互联网——根服务器

说明 根服务器是互联网域名系统(DNS)中最高级别的服务器之一。它们负责管理整个DNS系统的顶级域名空间,例如.com、.org和.net等。 根服务器的主要功能是将用户的DNS查询转发到适当的顶级域名服务器。当用户在浏览器中输入一个域名&#xff…

JVM基础篇-本地方法栈与堆

JVM基础篇-本地方法栈与堆 本地方法栈 什么是本地方法? 本地方法即那些不是由java层面实现的方法,而是由c/c实现交给java层面进行调用,这些方法在java中使用native关键字标识 public native int hashCode()本地方法栈的作用? 为本地方法提供内存空…

Centos7 下 部署开源tesseract-ocr完整教程

Centos 7 下部署 tesseract5 我的 Centos7 是一个干净的系统,另外下述操作步骤亲测. 参考博客 http://www.nanstar.top/p/wiki_1649411481701https://segmentfault.com/a/1190000041832780 相关资源下载地址 https://download.csdn.net/download/qq_33547169/881…

Intellij IDEA运行报Command line is too long的解决办法

想哭,vue前端运行起来,对应的后端也得起服务。 后端出的这个bug,下面的博客写的第二种方法,完整截图是下面这个。 ​​​​​​​​​​​​​​​​​​​​Intellij IDEA运行报Command line is too long的解决办法 - 知乎 (zh…

SpringBoot项目配置多数据源实现查询

一.所需依赖 <dependencies><dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter</artifactId></dependency><dependency><groupId>org.springframework.boot</groupId><…

装修小程序,开启装修公司智能化服务的新时代

随着数字化时代的来临&#xff0c;装修小程序成为提升服务质量和效率的关键工具。装修小程序旨在为装修公司提供数字化赋能、提高客户满意度的智慧装修平台。通过装修小程序&#xff0c;装修公司能够与客户进行在线沟通、展示设计方案、提高服务满意度等操作。 装修小程序的好处…