KNN(K近邻算法)

devtools/2024/10/21 5:48:42/

  k近邻,顾名思义,就是寻找距离测试点最近的 k 个点,根据这 k 个点的标签来判断该测试点的标签。
  如下图所示,图中有10个样本点,若要对图中的绿点1分类,k算法>近邻算法采用的策略是(下图中 k 值为 3 ),找到距离绿点1最近的三个点,其分别是 2、3、4,其中 2 为蓝色,3、4为红色,因为红色占大多数,所以 1 就会被分类到红色阵营里面。

  对于样本点距离的计算,一般是采用欧几里得距离,在二维特征情况下,其计算公式为: ( x 2 2 − x 1 2 ) + ( y 2 2 − y 1 2 ) \sqrt{(x_2^2-x_1^2)+(y_2^2-y_1^2)} (x22x12)+(y22y12) 但对于多维的情况下,计算公式为: ∑ i = 1 n ( q i 2 − p i 2 ) \sqrt{\sum_{i=1}^{n}(q_i^2-p_i^2)} i=1n(qi2pi2) 其中 n 为样本数据的特征维度。

KNN实现代码

from collections import Counter
import numpy as npdef euclidean_distance(x1, x2):return np.sqrt(np.sum((x1 - x2) ** 2))class KNN:def __init__(self, k=3):self.k = kdef fit(self, X, y):self.X_train = Xself.y_train = ydef predict(self, X):# 使用_predict函数对每个测试样本进行预测,得到每个样本的标签predicted_labels = [self._predict(x) for x in X]return np.array((predicted_labels))def _predict(self, x):# 计算测试点到每个训练点的距离distances = [euclidean_distance(x, x_train) for x_train in self.X_train]# 获取K个最近的样本及标签k_indices = np.argsort(distances)[:self.k]k_nearest_labels = [self.y_train[i] for i in k_indices]# 获取数量最多的标签most_common = Counter(k_nearest_labels).most_common(1)return most_common[0][0]

对于其中的某些函数:

argsort函数:它会对数组进行排序,但返回的不是排序后的数组,而是原数组中每个元素在排序后数组中的位置索引。

import numpy as np  arr = np.array([3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5])  
sorted_indices = np.argsort(arr)  print("原数组:", arr)  
print("排序后的索引:", sorted_indices)  
print("按索引排序后的数组:", arr[sorted_indices])输出:
原数组: [3 1 4 1 5 9 2 6 5 3 5]  
排序后的索引: [ 1  3  6  0  9  2  4  8 10  5]  
按索引排序后的数组: [1 1 2 3 3 4 5 5 5 6 9]

Counter函数:统计每个元素及其出现的次数。

from collections import Counter  words = ['apple', 'banana', 'apple', 'orange', 'banana', 'grape']  
word_counts = Counter(words)  print("单词计数:", word_counts)  
print("最常见的单词:", word_counts.most_common(1))  # 获取出现次数最多的单词输出:
单词计数: Counter({'apple': 2, 'banana': 2, 'orange': 1, 'grape': 1})  
最常见的单词: [('apple', 2)]

KNN 测试程序

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import ListedColormap
from sklearn import datasets
from sklearn.model_selection import train_test_splitiris = datasets.load_iris()
X, y = iris.data, iris.targetX_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=1234)# (120, 4)共有120个数据,其中每个数据的维度为4
# print(X_train.shape)
# print(X_test.shape)
# [5.1 2.5 3.  1.1]通过观察可发现第一个样本有四个特征
# print(X_train[0])
# 训练标签共有120个数据,每个数据都是一维
# print(y_train.shape)
# 展示所有的标签
# print(y_train)# 把数据的前两维特征以点状图呈现出来,其中颜色按照标签的不同进行分类
plt.figure()
cmap = ListedColormap(['#FF0000', '#00FF00', '#0000FF'])
plt.scatter(X[:, 0], X[:, 1], c=y, cmap=cmap, edgecolors='k', s=20)
plt.show()from KNN import KNNclf = KNN(3)
clf.fit(X_train, y_train)
predictions = clf.predict(X_test)acc = np.sum(predictions == y_test) / len(y_test)
print(acc)

数据样本前二维分布为:

输出结果为 1.0 可见 KNN 在此分类问题中还是有不错效果的。

本文参考视频:KNN


http://www.ppmy.cn/devtools/127488.html

相关文章

探索Python中的多线程与多进程

在Python编程中,多线程和多进程是两个重要的概念,它们被用来提高程序的执行效率。本文将深入探讨这两个概念,并对比它们在Python中的实现方式。 一、多线程 多线程是一种并发执行的程序设计方法。在Python中,我们可以使用thread…

[Python学习日记-52] Python 中的 copy 模块 —— shutil

[Python学习日记-52] Python 中的 copy 模块 —— shutil 简介 shutil 模块 简介 在前面的学习当中,我们学习了如何在 Python 中创建文件,这个时候我们基本已经有了写程序的能力了,而有的时候我们也想使用 Python 来对文件进行更多的操作&a…

QT 软件打包为一个单独可执行.exe文件

将 QT 应用程序打包为一个独立的可执行文件 (.exe) 以便于分发通常包括以下几个步骤。以下是详细的流程和说明: 1. 准备环境 确保已经安装了以下软件: Qt SDK:可以从 Qt 官网 下载。Qt Creator:通常包含在 Qt SDK 中。MinGW 或…

一起搭WPF架构之livechart的MVVM使用介绍

一起搭WPF架构之livechart使用介绍 前言ModelViewModelView界面设计界面后端 效果总结 前言 简单的架构搭建已经快接近尾声了,考虑设计使用图表的形式将SQLite数据库中的数据展示出来。前期已经介绍了livechart的安装,今天就详细介绍一下livechart的使用…

关于Vue脚手架

一、简介与安装 1 简介 Vue Cli 全称Vue command line interface(Vue命令行接口),俗称Vue脚手架, 是Vue官方提供的一个标准化开发工具(开发平台)。 可以帮助我们快速创建一个开发Vue项目的标准化基础架子。【集成了webpack配置】 参考官网&#xff1a…

使用 surya-ocr 进行文字识别

surya-ocr 是一个开源的 OCR 模型,个人用是免费的,商用是需要License,收费标准有些复杂,具体可以查看官网。 主要包括以下功能: 支持 90 多种语言的 OCR任何语言的行级文本检测版面分析(表格、图像、标题等…

高级java每日一道面试题-2024年10月12日-Web篇-http,servlet,tomcat 之间是什么关系?

如果有遗漏,评论区告诉我进行补充 面试官: http,servlet,tomcat 之间是什么关系? 我回答: HTTP(超文本传输协议)、Servlet 和 Tomcat 之间的关系可以这样理解:它们是构建Web应用程序的不同层次的技术。下…

帝国cms取得内容和栏目链接地址的方法

用以下2个函数解决内容页面和栏目页面链接,可有效解决更改URL显示方式(动态、静态、伪静态)不需要修改模版中的链接地址。 内容页链接地址: $infourlsys_ReturnBqTitleLink($r); $r为含“id,classid,newspath,filename,groupid,ti…