我们将使用GroupLens研究小组收集的MovieLens数据集。
这个数据集描述了MovieLens的五星评级和标记活动。该数据集包含来自600多名用户的9000多部电影的约10万个评分。我们将使用该数据集生成两种节点类型,分别保存电影和用户的数据,以及一种连接用户和电影的边类型,表示用户对特定电影的评分关系。
首先,我们将数据集下载到任意文件夹(在本例中为当前目录):
from torch_geometric.data import download_url, extract_zipurl = 'https://files.grouplens.org/datasets/movielens/ml-latest-small.zip'
extract_zip(download_url(url, '.'), '.')movies_path = './ml-latest-small/movies.csv'
ratings_path = './ml-latest-small/ratings.csv'
打开数据集,就可以看到以下一些文件:
import pandas as pdprint(pd.read_csv(movies_path).head()) # DataFrame对象,默认显示前5行
print(pd.read_csv(ratings_path).head())
为了用PyG数据格式表示这些数据,我们首先定义了一个方法load_node_csv()
,该方法读取*.csv文件并返回形状为[num_nodes,num_features]
的节点级特征表示x:
import torchdef load_node_csv(path, index_col, encoders=None, **kwargs): # **kwargs用于在函数定义中接收任意数量的关键字参数,是一个字典df = pd.read_csv(path, index_col=index_col, **kwargs) # 读取*.csvmapping = {index: i for i, index in enumerate(df.index.unique())} # 将索引映射成连续值x = Noneif encoders is not None:xs = [encoder(df[col]) for col, encoder in encoders.items()]x = torch.cat(xs, dim=-1)return x, mapping
这里,load_node_csv()
从路径读取*.csv文件,并创建一个字典映射,将其索引列映射到范围{0,…,num_rows-1}中的连续值。这是必要的,因为我们希望我们的最终数据表示尽可能紧凑,例如,第一行中的电影表示应该可以通过x[0]访问。
from sentence_transformers import SentenceTransformer
class SequenceEncoder:def __init__(self, model_name='all-MiniLM-L6-v2', device=None):self.device = deviceself.model = SentenceTransformer(model_name, device=device)@torch.no_grad()def __call__(self, df):x = self.model.encode(df.values, show_progress_bar=True,convert_to_tensor=True, device=self.device)print(x.shape)return x.cpu()
SequenceEncoder
类加载一个由model_name给定的预先训练的NLP模型,并使用它将字符串列表编码为形状为[num_strings,embedding_dim]
的PyTorch张量。我们可以使用此SequenceEncoder
对movies.csv
文件的标题进行编码。
以类似的方式,我们可以创建另一个编码器,将电影类型转换为分类标签。为此,我们首先需要找到数据中存在的所有电影类型,创建shape[num_movies,num_genres]
的特征表示x,并在类型j存在于电影i中的情况下将1分配给x[i,j]:
class GenresEncoder:def __init__(self, sep='|'):self.sep = sepdef __call__(self, df):genres = set(g for col in df.values for g in col.split(self.sep))mapping = {genre: i for i, genre in enumerate(genres)}x = torch.zeros(len(df), len(mapping))for i, col in enumerate(df.values):for genre in col.split(self.sep):x[i, mapping[genre]] = 1print(x.shape)return x
有了这个,我们可以通过以下方式获得我们对电影的最终呈现:
类似地,我们也可以使用load_node_csv()来获得从userId到连续值的用户映射。但是,此数据集中没有用户的其他特征信息。因此,我们没有定义任何编码器:
这样,我们就可以初始化HeteroData对象,并将两种节点类型传递给它:
from torch_geometric.data import HeteroDatadata = HeteroData()data['user'].num_nodes = len(user_mapping) # Users do not have any features.
data['movie'].x = movie_xprint(data)
print(movie_x.shape)
由于用户没有任何节点级别的信息,我们只定义其节点数。因此,在异构图模型的训练过程中,我们可能需要通过torch.nn.Embedding
以端到端的方式学习不同的用户嵌入。
接下来,我们来看看根据用户的评分将他们与电影联系起来。为此,我们定义了一个方法load_edge_csv()
,该方法从ratings.csv
返回shape[2,num_ratings]
的最终edge_index
表示,以及原始*.csv文件中存在的任何其他功能:
def load_edge_csv(path, src_index_col, src_mapping, dst_index_col, dst_mapping,encoders=None, **kwargs):df = pd.read_csv(path, **kwargs)src = [src_mapping[index] for index in df[src_index_col]]dst = [dst_mapping[index] for index in df[dst_index_col]]#print(len(src))#print(len(dst))edge_index = torch.tensor([src, dst])edge_attr = Noneif encoders is not None:edge_attrs = [encoder(df[col]) for col, encoder in encoders.items()]edge_attr = torch.cat(edge_attrs, dim=-1)#print(edge_attr.shape)return edge_index, edge_attr
这里,src_index_col
和dst_index_col
分别定义源节点和目标节点的索引列。我们进一步利用节点级映射src_mapping
和dst_mapping
来确保原始索引在我们的最终表示中被映射到正确的连续索引。
对于文件中定义的每条边,它会在src_mapping
和dst_mapping
中查找正向索引,并适当地移动数据。
与load_node_csv()
类似,编码器用于返回额外的边特征信息。例如,为了从ratings.csv
中的rating列加载ratings,我们可以定义一个IdentityEncoder
,它只需将浮点值列表转换为PyTorch张量:
class IdentityEncoder:def __init__(self, dtype=None):self.dtype = dtypedef __call__(self, df):return torch.from_numpy(df.values).view(-1, 1).to(self.dtype)
这样,我们就可以完成我们的HeteroData对象了:
edge_index, edge_label = load_edge_csv(ratings_path,src_index_col='userId',src_mapping=user_mapping,dst_index_col='movieId',dst_mapping=movie_mapping,encoders={'rating': IdentityEncoder(dtype=torch.long)},
)data['user', 'rates', 'movie'].edge_index = edge_index
data['user', 'rates', 'movie'].edge_label = edge_labelprint(data)
该HeteroData对象是PyG中异构图的原生格式,可以用作异构图模型的输入。
本文内容参考:PyG官网
视频讲解:4.如何用CSV文件生成异构图数据集