机器学习系列----KNN分类

devtools/2024/11/18 5:28:44/

目录

前言

一.KNN算法的基本原理

二.KNN分类的实现

三.总结


前言

机器学习领域,K近邻算法(K-Nearest Neighbors, KNN)是一种非常直观且常用的分类算法。它是一种基于实例的学习方法,也被称为懒学习(Lazy Learning),因为它在训练阶段不进行任何模型的构建,所有的计算都推迟到测试阶段进行。KNN分类的核心思想是:给定一个测试样本,找到在训练集中与其距离最近的K个样本,然后根据这K个样本的标签进行预测。

本文将介绍KNN算法的基本原理、如何实现KNN分类,以及在实际使用中需要注意的几点。

一.KNN算法的基本原理

KNN算法的基本流程如下:

(1)选择距离度量:通常我们使用欧氏距离来衡量两个样本点之间的距离,但也可以选择其他距离度量,如曼哈顿距离、余弦相似度等。

(2)选择K值:选择K的大小会直接影响分类效果。K值太小容易受到噪声数据的影响,而K值过大可能导致分类结果过于平滑。

(3)找到K个邻居:对于测试样本,根据距离度量选择与之最接近的K个样本。

(4)投票决策:通过这K个邻居的类别标签进行投票,测试样本的预测标签通常由出现频率最高的类别决定。

二.KNN分类的实现

在Python中,我们可以通过 sklearn 库来快速实现KNN分类器。下面是一个使用KNN进行分类的基本示例:

导入必要的库

import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.datasets import load_iris
from sklearn.metrics import accuracy_score

 加载数据集
我们使用sklearn自带的鸢尾花(Iris)数据集,该数据集包含150个样本,4个特征,3个类别。

# 加载鸢尾花数据集
iris = load_iris()
X = iris.data  # 特征
y = iris.target  # 标签

数据集拆分
将数据集拆分为训练集和测试集,训练集占80%,测试集占20%。

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

初始化KNN分类器并训练
我们创建一个KNN分类器实例,选择K=3。

# 初始化KNN分类器,设置K=3
knn = KNeighborsClassifier(n_neighbors=3)# 在训练集上训练模型
knn.fit(X_train, y_train)

测试与评估
我们可以使用测试集来评估模型的准确性。

# 在测试集上做预测
y_pred = knn.predict(X_test)# 计算准确率
accuracy = accuracy_score(y_test, y_pred)
print(f"模型准确率: {accuracy * 100:.2f}%")

KNN分类的优缺点
优点
简单易懂:KNN算法简单直观,不需要复杂的训练过程。
无参数假设:KNN不需要像其他算法那样对数据做参数假设,适应性强。
适合多分类问题:KNN能够有效地处理多类问题。
缺点
计算开销大:在测试阶段需要计算每个测试点与所有训练数据的距离,计算量大,尤其在数据量较大时,效率较低。
对噪声敏感:由于KNN依赖于距离度量,数据中的噪声点可能会影响分类结果。
需要存储整个训练集:KNN算法是懒学习,需要将训练集存储在内存中,可能会对内存消耗产生较大影响。
K值的选择与调优
选择合适的K值是KNN分类器表现的关键。过小的K值(例如1)容易过拟合,受噪声影响较大,而过大的K值会导致欠拟合。常用的选择方法是通过交叉验证来选择K值。 

from sklearn.model_selection import cross_val_score# 使用交叉验证选择K值
k_values = range(1, 21)
cv_scores = [np.mean(cross_val_score(KNeighborsClassifier(n_neighbors=k), X, y, cv=5)) for k in k_values]# 输出不同K值的交叉验证得分
for k, score in zip(k_values, cv_scores):print(f"K={k}, Cross-validation accuracy={score:.2f}")

 

 

import numpy as np
import math
from collections import Counter
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score
from sklearn.metrics import confusion_matrix, classification_report
import matplotlib.pyplot as plt# 计算欧几里得距离
def euclidean_distance(x1, x2):"""计算欧几里得距离x1, x2: 两个输入样本,numpy数组或列表"""return math.sqrt(np.sum((x1 - x2) ** 2))# 计算曼哈顿距离
def manhattan_distance(x1, x2):"""计算曼哈顿距离(L1距离)x1, x2: 两个输入样本,numpy数组或列表"""return np.sum(np.abs(x1 - x2))# 计算闵可夫斯基距离
def minkowski_distance(x1, x2, p=3):"""计算闵可夫斯基距离,p为距离的阶数x1, x2: 两个输入样本,numpy数组或列表p: 阶数,通常为 1 (曼哈顿距离) 或 2 (欧几里得距离)"""return np.power(np.sum(np.abs(x1 - x2) ** p), 1/p)# KNN 分类器类
class KNN:def __init__(self, k=3, distance_metric='euclidean'):"""初始化 KNN 分类器k: 最近邻的个数distance_metric: 距离度量方式,'euclidean' 为欧几里得距离,'manhattan' 为曼哈顿距离,'minkowski' 为闵可夫斯基距离"""self.k = kself.distance_metric = distance_metricdef fit(self, X_train, y_train):"""训练模型,保存训练数据X_train: 训练特征数据y_train: 训练标签数据"""self.X_train = X_trainself.y_train = y_traindef predict(self, X_test):"""对测试数据进行预测X_test: 测试特征数据返回预测标签"""predictions = [self._predict(x) for x in X_test]return np.array(predictions)def _predict(self, x):"""对单个样本进行预测x: 输入样本返回预测标签"""# 根据指定的距离度量方法计算距离if self.distance_metric == 'euclidean':distances = [euclidean_distance(x, x_train) for x_train in self.X_train]elif self.distance_metric == 'manhattan':distances = [manhattan_distance(x, x_train) for x_train in self.X_train]elif self.distance_metric == 'minkowski':distances = [minkowski_distance(x, x_train) for x_train in self.X_train]else:raise ValueError(f"Unsupported distance metric: {self.distance_metric}")# 找到最近的 k 个邻居k_indices = np.argsort(distances)[:self.k]k_nearest_labels = [self.y_train[i] for i in k_indices]# 返回最常见的标签most_common = Counter(k_nearest_labels).most_common(1)return most_common[0][0]# 加载 Iris 数据集
iris = load_iris()
X = iris.data  # 特征数据
y = iris.target  # 标签数据# 切分数据集为训练集和测试集,70% 训练集,30% 测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)# 标准化数据,以确保不同特征的数值范围一致
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)# 初始化 KNN 分类器
knn = KNN(k=5, distance_metric='minkowski')  # 使用闵可夫斯基距离
knn.fit(X_train, y_train)# 预测
predictions = knn.predict(X_test)# 计算准确率
accuracy = accuracy_score(y_test, predictions)
print(f"预测准确率: {accuracy * 100:.2f}%")# 输出混淆矩阵和分类报告
print("\n混淆矩阵:")
print(confusion_matrix(y_test, predictions))print("\n分类报告:")
print(classification_report(y_test, predictions))# 绘制预测结果与真实结果对比的图表
def plot_confusion_matrix(cm, classes, title='Confusion Matrix', cmap=plt.cm.Blues):"""绘制混淆矩阵cm: 混淆矩阵classes: 类别名"""plt.imshow(cm, interpolation='nearest', cmap=cmap)plt.title(title)plt.colorbar()tick_marks = np.arange(len(classes))plt.xticks(tick_marks, classes, rotation=45)plt.yticks(tick_marks, classes)# 绘制网格thresh = cm.max() / 2.for i, j in np.ndindex(cm.shape):plt.text(j, i, format(cm[i, j], 'd'),horizontalalignment="center",color="white" if cm[i, j] > thresh else "black")plt.tight_layout()plt.ylabel('True label')plt.xlabel('Predicted label')# 计算混淆矩阵
cm = confusion_matrix(y_test, predictions)# 绘制混淆矩阵
plt.figure(figsize=(8, 6))
plot_confusion_matrix(cm, classes=iris.target_names)
plt.show()# 展示一些预测结果
for i in range(5):print(f"实际标签: {iris.target_names[y_test[i]]}, 预测标签: {iris.target_names[predictions[i]]}")

 

代码功能解释:
计算不同距离:

euclidean_distance:计算欧几里得距离。
manhattan_distance:计算曼哈顿距离。
minkowski_distance:计算闵可夫斯基距离,p 代表阶数,通常取 1(曼哈顿距离)或者 2(欧几里得距离)。
KNN 分类器:

在 KNN 类中,你可以选择不同的距离度量方式 ('euclidean', 'manhattan', 'minkowski'),通过 k 来设定邻居个数。
fit 方法保存训练数据,predict 方法对每个测试数据点进行预测。
_predict 方法对单个测试样本进行预测,通过计算与训练集中所有样本的距离来选择最近的 k 个邻居。
数据预处理:

使用 StandardScaler 来标准化数据,使得每个特征具有零均值和单位方差。
模型评估:

使用 accuracy_score 计算预测准确率。
使用 confusion_matrix 和 classification_report 来展示混淆矩阵和分类性能报告(包括精确度、召回率、F1 分数等)。
通过 matplotlib 绘制混淆矩阵,帮助可视化模型的分类效果。
数据集:

使用 sklearn.datasets 中的 Iris 数据集。该数据集包含 150 个样本,分别属于 3 个不同的鸢尾花种类,每个样本有 4 个特征。
输出:
准确率:模型对测试集的预测准确性。
混淆矩阵:展示真实标签与预测标签的对比。
分类报告:包含精确度、召回率、F1 分数等详细指标。
混淆矩阵图表:图形化展示分类性能。
这个实现包含了更多的功能,并且通过使用不同的距离度量方法,你可以探索 KNN 在不同设置下的表现。

三.总结

K 最近邻(KNN)算法是一种简单直观的监督学习算法,广泛应用于分类和回归问题。其核心思想是,通过计算待预测样本与训练集中的每个样本之间的距离,选择距离最近的 k 个样本(即“邻居”),然后根据这些邻居的标签或数值来进行预测。在分类问题中,KNN 通过多数投票原则决定最终分类结果;在回归问题中,则通常是取邻居标签的平均值。KNN 算法的优势在于不需要显式的训练过程,其预测过程依赖于对整个训练数据集的存储和计算,因此适合动态更新数据的场景。然而,KNN 算法的计算复杂度较高,尤其在数据集较大时,预测过程可能变得非常缓慢。为了提高效率,通常需要对数据进行预处理,如归一化或标准化,以消除不同特征尺度差异的影响。此外,K 值的选择以及距离度量方法(如欧几里得距离、曼哈顿距离等)会显著影响模型的表现,K 值过小可能导致过拟合,过大则可能导致欠拟合。KNN 的一个主要缺点是它对高维数据(即特征空间维度较大)不太敏感,因为高维空间的距离度量往往会失去区分度,导致“维度灾难”。总的来说,KNN 是一个易于理解和实现的算法,适用于样本量不大且特征维度较低的问题,但在大数据集和高维数据上可能不够高效。


http://www.ppmy.cn/devtools/134876.html

相关文章

Jmeter中的前置处理器(二)

5--JDBC PreProcessor 功能特点 数据库操作:执行SQL语句,支持插入、删除、更新和查询操作。灵活配置:可以连接多种数据库(如MySQL、Oracle、PostgreSQL等)。适用于数据准备:特别适合需要在测试前准备数据…

docker安装宝塔,Mac也可以使用宝塔搭建开发环境了

宝塔没有mac版本,如果想在Mac本地搭建宝塔环境做php开发的,可以使用docker的方式部署 新建一个目录baota,目录下新建文件docker-compose.yml services:dzhbt-ubuntu:image: registry.cn-heyuan.aliyuncs.com/gzdzh/baota:dzhbt-ubuntu-arm-l…

编写一个生成凯撒密码的程序

plain list(input("请输入需要加密的明文(只支持英文字母):"))key int(input("请输入移动的位数:"))base_A ord(A)base_a ord(a)cipher []for each in plain:if each :cipher.append( )else:if each.i…

某杀软环境下的添加账户

某杀软环境下的添加账户 我们在某个杀软环境下,正常添加账户一般是会被直接拦截的 白+黑 在这个环境下,白+黑是最实用的绕过方式,我们可以通过调用winapi来创建账户,这些代码再存储到dll里面&#xff0c…

SQL Server Management Studio 的JDBC驱动程序和IDEA 连接

一、数据库准备 (一)启用 TCP/IP 协议 操作入口 首先,我们要找到 SQL Server 配置管理器,操作路径为:通过 “此电脑” 右键选择 “管理”,在弹出的 “计算机管理” 窗口中,找到 “服务和应用程…

Linux解决 -bash: nc: command not found-bash: nc: 未找到命令

[rootbigdata01 ~]# nc -lk 9999 -bash: nc: command not found(-bash: nc: 未找到命令) 没有命令安装一下即可 yum install -y nc [rootbigdata01 ~]# yum install -y nc 已加载插件:fastestmirror Determining fastest mirrors * base: mi…

麒麟nginx配置

一、配置负载均衡 配置麒麟的yum源 vim /etc/yum.repos.d/kylin_aarch64.repo Copy 删除原来内容,写入如下yum源 [ks10-adv-os] name Kylin Linux Advanced Server 10 - Os baseurl http://update.cs2c.com.cn:8080/NS/V10/V10SP2/os/adv/lic/base/aarch64/ …

SpringCloud Feign 报错 Request method ‘POST‘ not supported 的解决办法

Request method POST not supportedorg.springframework.web.HttpRequestMethodNotSupportedException: Request method POST not supported解决办法: 在远程调用fegin使用GET请求时 应该附加注解 RequestParam(“pgQuery”) 实体类或者单个参数同样适用 在controller接受参数…