sklearn 笔记 BallTree/KD Tree

news/2025/2/13 23:51:07/

由NearestNeighbors类包装

1 主要使用方法

sklearn.neighbors.BallTree(X, leaf_size=40, metric='minkowski', **kwargs)
X数据集中的点数
leaf_size改变 leaf_size 不会影响查询的结果,但可以显著影响查询的速度和构建树所需的内存
metric用于距离计算的度量。默认为 "minkowski"

2 主要方法

2.1 get_arrays

import numpy as np
from sklearn.neighbors import BallTree
X = np.random.random((10, 3))
tree = BallTree(X)                
tree.get_arrays()'''
(array([[0.90651098, 0.68471698, 0.6299996 ],[0.82751465, 0.31739009, 0.61572299],[0.22778906, 0.63614041, 0.73672184],[0.64655758, 0.9729849 , 0.68232389],[0.94992886, 0.72604933, 0.45649069],[0.34932115, 0.95985124, 0.41451989],[0.45131894, 0.21650206, 0.82466273],[0.87047096, 0.48403116, 0.58119046],[0.94468825, 0.14985636, 0.12132986],[0.62717326, 0.12924198, 0.23928098]]),array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int64),array([(0, 10, 1, 0.61638879)],dtype=[('idx_start', '<i8'), ('idx_end', '<i8'), ('is_leaf', '<i8'), ('radius', '<f8')]),array([[[0.68012737, 0.52767645, 0.53022429]]]))
'''
  • 返回了4个数组
    • 第一个数组:原始数据点数组

    • 第二个数组:整数数组,代表每个点的索引

    • 第三个数组:结构化数组,包含了 BallTree 的内部树结构的信息

      • idx_startidx_end:定义了存储在当前节点的点的索引范围。
      • is_leaf:表明当前节点是否是叶节点。
      • radius:当前节点中所有点到节点中心点的最大距离
    • 第四个数组:树的每个节点的中心点

2.2 get_tree_stats

获取 BallTree 的状态信息:树的剪枝次数、叶节点的数量、分裂次数

2.3 query

查询树以找到 k 个最近邻居

query(X, k=1, return_distance=True, dualtree=False, breadth_first=False)
X要查询的点的数组
k

(int,默认为1)

要返回的最近邻居的数量

return_distance

(bool,默认为True)

如果为 True,返回一个包含距离和索引的元组 (d, i);

如果为 False,只返回数组 i

dualtree

(bool,默认为False):

如果为 True,使用双树形式进行查询:为查询点构建一个树,并使用这对树来高效地搜索这个空间当点的数量变得很大时,这可以带来更好的性能

breadth_first

(bool,默认为False)

如果为 True,则以广度优先的方式查询节点。否则,以深度优先的方式查询

sort_results

(bool,默认为True)

如果为 True,则在返回时对每个点的距离和索引进行排序,使得第一列包含最近的点

import numpy as np
from sklearn.neighbors import BallTree
X = np.random.random((100, 3))
tree = BallTree(X)                
tree.query(X[:3],k=3)
'''
(array([[0.        , 0.08335798, 0.15625817],[0.        , 0.06843236, 0.10825558],[0.        , 0.0968137 , 0.10245125]]),array([[ 0, 59, 88],[ 1, 70,  5],[ 2, 43, 20]], dtype=int64))
'''

2.4 query_radius

  • 进行半径查询的功能
  • 查询树,以找出在指定半径 r 内的邻居点
query_radius(X, r, return_distance=False, count_only=False, sort_results=False)
X要查询的点的数组
r

返回邻居的距离范围

r 可以是单个值,也可以是一个数组,形状为 x.shape[:-1],如果每个点需要不同的半径

return_distance

(bool,默认为False)

如果为 True,则返回每个点的邻居距离;如果为 False,则只返回邻居

query() 方法不同,这里设置 return_distance=True 会增加计算时间。如果 return_distance=False,并不需要显式计算所有距离

count_only

(bool,默认为False)

如果为 True,则只返回距离 r 内的点的数量;

如果为 False,则返回距离 r 内所有点的索引

sort_results

(bool,默认为False)

如果为 True,则在返回之前对距离和索引进行排序。如果为 False,则结果不排序

import numpy as np
from sklearn.neighbors import BallTree
X = np.random.random((100, 3))
tree = BallTree(X)                
tree.query_radius(X[:3],r=0.3)
'''
array([array([ 0, 68, 11, 31, 46, 19, 36, 63, 16, 86, 79], dtype=int64),array([26, 64, 20, 94,  1,  4, 13,  3], dtype=int64),array([35, 50, 30, 83, 85, 18, 15, 53,  2, 96, 81], dtype=int64)],dtype=object)
'''

2.5 two_point_correlation

计算距离小于等于r[i]的点的数量

two_point_correlation(X, r, dualtree=False)
X要查询的点集
r一维数组,包含距离值
dualtree

如果为 True,则使用双树算法;否则,使用单树算法。

对于大量数据点(N),双树算法可能有更好的扩展性

返回值

counts (ndarray): counts[i] 包含距离小于或等于 r[i] 的点对数

import numpy as np
from sklearn.neighbors import BallTree
X = np.random.random((100, 3))
r=np.linspace(0.1,1,5)
tree = BallTree(X)                
tree.two_point_correlation(X[:3],r=r)
#array([  4,  34,  99, 196, 263], dtype=int64)
'''
返回的第一个值:和X[0]的距离小于r[0]的数量+和X[1]的距离小于r[0]的数量+和X[2]的距离小于r[0]的数量
'''

3 KD-Tree

和Ball-Tree 一模一样


http://www.ppmy.cn/news/1221171.html

相关文章

大数据-之LibrA数据库系统告警处理(ALM-12047 网络读包错误率超过阈值)

告警解释 系统每30秒周期性检测网络读包错误率&#xff0c;并把实际错误率和阈值&#xff08;系统默认阈值0.5%&#xff09;进行比较&#xff0c;当检测到网络读包错误率连续多次&#xff08;默认值为5&#xff09;超过阈值时产生该告警。 用户可通过“系统设置 > 阈值配置…

LC349. 两个数组的交集

/*** 方法一* 创建set1和set2来装nums1和nums2中的元素* 创建set3来装交集数据* 我们遍历set1,取出每一个item,如果set2中含有该item,则该item就是set1和set2共有的元素,因此将item放入set3中* 最后将set3转换成数组返回* param nums1* param nums2* return*/public static int…

反序列化漏洞(2), 分析调用链, 编写POC

反序列化漏洞(2), 反序列化调用链分析 一, 编写php漏洞脚本 http://192.168.112.200/security/unserial/ustest.php <?php class Tiger{public $string;protected $var;public function __toString(){return $this->string;}public function boss($value){eval($valu…

R脚本进行长宽数据转换

1.R脚本进行长宽数据转换 library(tidyverse) df tibble(Class c("1班", "2班"),Name c("张三&#xff0c;李四&#xff0c;王五", "赵六&#xff0c;钱七")) df## # A tibble: 2 x 2 ## Class Name ## <chr> <chr&g…

DbUtils示例

DbUtils:JDBC实用程序组件示例 本页提供了一些示例&#xff0c;说明如何使用Dbutils。 基本用途 DbUtils是一个非常小的类库&#xff0c;因此不需要很长时间就可以遍历每个类的javadocs。DbUtils中的核心类/接口是QueryRunner和ResultSetHandler。您不需要了解任何其他DbUti…

虾皮之家数据分析插件:知虾数据分析工具提升销量的利器

在当今的电商市场中&#xff0c;虾皮Shopee成为了许多商家的首选平台。然而&#xff0c;随着竞争的加剧&#xff0c;店铺运营变得越来越具有挑战性。如何提升销量&#xff0c;优化标题和图片&#xff0c;合理设置SKU&#xff0c;并准确跟踪店铺活动数据和竞品数据&#xff0c;已…

ERROR: column “xxxx.id“ must appear in the GROUP BY

org.postgresql.util.PSQLException: ERROR: column “xxx.id” must appear in the GROUP BY clause or be used in an aggregate function 错误**&#xff1a;列“XXXX.id”必须出现在GROUP BY子句中或在聚合函数中使用** 出现这种错误的sql如下&#xff1a; select name,…

我的项目分享(不喜勿喷)

我要分享的项目是大喇叭C2C电商平台系统&#xff0c;一个面向移动端的电子商务平台&#xff0c;为个体消费者和商家提供直接交易和沟通的便利&#xff0c;丰富了人们的生活。 主要功能模块&#xff1a; 该项目的主要功能包括&#xff1a; 1. 用户注册功能&#xff1a;使用正则…