目录
优化版,去重相似度
topN 欧式距离版
没有去重复,
优化版,去重相似度
import torch
import torch.nn.functional as F
torch.manual_seed(42)
# 假设 10 条数据,每条数据的特征维度是 128
data = torch.randn(10, 128)# 计算所有数据对之间的余弦相似度
cosine_similarities = F.cosine_similarity(data.unsqueeze(0), data.unsqueeze(1), dim=2)# 通过设置对角线为负无穷,排除自身相似度
cosine_similarities.fill_diagonal_(-float('inf'))# 生成上三角掩码(i < j 的位置为True)
mask = torch.triu(torch.ones_like(cosine_similarities, dtype=torch.bool), diagonal=1)# 过滤掉下三角和对角线,仅保留