机器学习:k近邻

devtools/2025/2/21 10:40:52/

所有代码和文档均在golitter/Decoding-ML-Top10: 使用 Python 优雅地实现机器学习十大经典算法。 (github.com),欢迎查看。

K 邻近算法(K-Nearest Neighbors,简称 KNN)是一种经典的机器学习算法,主要用于分类和回归任务。它的核心思想是:给定一个新的数据点,通过查找训练数据中最接近的 K 个邻居,并根据这些邻居的标签来预测新数据点的标签。

KNN 是一种 基于实例的学习(Instance-based learning)算法。在训练阶段,它并不构建显式的模型,而是将训练数据存储起来,在预测阶段计算待预测点与训练集中所有点的距离,然后选择 K 个最近的邻居,根据邻居的标签进行投票或平均来做出预测。

KNN 的优点在于其简单易懂、无需训练过程,并且适用于大多数任务。它能够处理复杂的非线性问题,不依赖数据分布假设,能够很好地适应复杂的决策边界。

然而,KNN 的缺点也很明显。它的计算开销大,因为每次预测都需要计算所有训练数据的距离,导致在大数据集上表现不佳。此外,KNN 需要存储所有训练数据,占用较大的内存空间,并且对异常值敏感,可能会影响预测结果的准确性。

KNN算法步骤:

  1. 选择 K 个邻居的数量,K 值通常是一个奇数,以避免平票的情况。
  2. 计算待预测数据点与训练数据集中每个点的距离。
  3. 根据计算出的距离选择 K 个最接近的点。
  4. 对于分类任务,返回 K 个邻居中最多的类别;对于回归任务,返回 K 个邻居标签的均值。

代码实现

数据处理:使用iris.data数据集,用PCA进行降维。

import numpy as np
import pandas as pddef pca(X: np.array, n_components: int) -> np.array:"""PCA 进行降维。"""# 1. 数据标准化(去均值)X_mean = np.mean(X, axis=0)X_centered = X - X_mean# 2. 计算协方差矩阵covariance_matrix = np.cov(X_centered, rowvar=False)# 3. 计算特征值和特征向量eigenvalues, eigenvectors = np.linalg.eig(covariance_matrix)# 4. 按特征值降序排序sorted_indices = np.argsort(eigenvalues)[::-1]top_eigenvectors = eigenvectors[:, sorted_indices[:n_components]]# 5. 投影到新空间X_pca = np.dot(X_centered, top_eigenvectors)return X_pcadef get_data():data = pd.read_csv('iris.csv', header=None)# print(data.dtypes)unq = data.iloc[:, -1].unique()for i, u in enumerate(unq):data.iloc[:, -1] = data.iloc[:, -1].apply(lambda x: i if x == u else x)# print(data.sample(5))xuanze = np.random.choice([True, False], len(data), replace=True, p=[0.8, 0.2])train_data = data[xuanze]test_data = data[~xuanze]train_data = np.array(train_data,dtype=np.float32,)test_data = np.array(test_data, dtype=np.float32)# 归一化train_data[:, :-1] = (train_data[:, :-1] - train_data[:, :-1].mean(axis=0)) / train_data[:, :-1].std(axis=0)test_data[:, :-1] = (test_data[:, :-1] - test_data[:, :-1].mean(axis=0)) / test_data[:, :-1].std(axis=0)return (pca(train_data[:, :-1], 2),train_data[:, -1].astype(np.int32),pca(test_data[:, :-1], 2),test_data[:, -1].astype(np.int32),)if __name__ == '__main__':x_train, y_train, x_test, y_test = get_data()print(y_train.dtype)print(x_test, y_test)print(x_train.shape, y_train.shape)

knn过程:

from data_processing import get_data
import numpy as np
import matplotlib.pyplot as pltdef euclidean_distance(x_train: np.array, x_test: np.array) -> np.array:"""计算欧拉距离"""return np.sqrt(np.sum((x_train - x_test) ** 2, axis=1))def knn(k: int, x_train: np.array, y_train: np.array, x_test: np.array) -> np.array:"""k近邻算法"""predictions = []for test in x_test:distances = euclidean_distance(x_train, test)nearest_indices = np.argsort(distances)[:k]  # 返回最近的k个点的索引nearest_labels = y_train[nearest_indices]  # 返回最近的k个点的标签prediction = np.argmax(np.bincount(nearest_labels))  # 返回最近的k个点中出现次数最多的标签predictions.append(prediction)return np.array(predictions)def accuracy(predictions: np.array, y_test: np.array) -> float:"""计算准确率"""return np.sum(predictions == y_test) / len(y_test)if __name__ == '__main__':k = 5x_train, y_train, x_test, y_test = get_data()predictions = knn(k, x_train, y_train, x_test)acc = accuracy(predictions, y_test)print(f'准确率为: {acc * 100:.2f}')# 绘制训练数据plt.scatter(x_train[:, 0], x_train[:, 1], c=y_train, cmap='viridis', marker='o', label='Train Data', alpha=0.7)# 绘制测试数据plt.scatter(x_test[:, 0], x_test[:, 1], c=y_test, cmap='coolwarm', marker='x', label='Test Data', alpha=0.7)# 绘制预测结果plt.scatter(x_test[:, 0],x_test[:, 1],c=predictions,cmap='coolwarm',marker='.',edgecolor='black',alpha=0.7,label='Predictions',)# 添加标题和标签plt.title('KNN Classification Results')plt.xlabel('Feature 1')plt.ylabel('Feature 2')plt.legend()# 显示图形plt.show()

在这里插入图片描述


http://www.ppmy.cn/devtools/160198.html

相关文章

首都国际会展中心启用,首展聚焦汽车后市场全产业链

首都国际会展中心启用,首展聚焦汽车后市场全产业链 2025年2月21日-24日,首都国际会展中心(新国展二期)迎来了其启用后的首场大型展览——第36届中国国际汽车服务用品及设备展览会暨中国国际新能源汽车技术、零部件及服务展览会&a…

【合集】Java进阶——Java深入学习的笔记汇总 再论面向对象、数据结构和算法、JVM底层、多线程、类加载、

前言 spring作为主流的 Java Web 开发的开源框架,是Java 世界最为成功的框架,持续不断深入认识spring框架是Java程序员不变的追求;而spring的底层其实就是Java,因此,深入学习Spring和深入学习Java是硬币的正反面&…

Spark 性能优化(四):Cache

在 Spark 中,缓存是一种将计算结果存储在内存中的方式,目的是加速后续操作。当你执行迭代算法或查询时,如果多次重复使用相同的数据集,缓存可以避免每次都重新计算相同的转换操作。通过缓存,Spark 可以将数据存储在内存…

AWS SES 邮件服务退信/投诉处理与最佳实践指南

在使用 AWS SES 发送邮件时,合理处理退信和投诉是维护发送声誉的关键。本文将详细介绍 SES 中的退信/投诉处理机制以及相关最佳实践。 一、退信处理机制 © ivwdcwso (ID: u012172506) 1.1 退信类型 在 SES 中,退信分为两种类型: 硬退信(Hard Bounce) 永久性错误,如无效…

【ISO 14229-1:2023 UDS诊断(ECU复位0x11服务)测试用例CAPL代码全解析⑪】

ISO 14229-1:2023 UDS诊断【ECU复位0x11服务】_TestCase11 作者:车端域控测试工程师 更新日期:2025年02月18日 关键词:UDS诊断协议、ECU复位服务、0x11服务、ISO 14229-1:2023 TC11-011测试用例 用例ID测试场景验证要点参考条款预期结果TC…

Python 自然语言处理(NLP)和文本挖掘的常规操作过程

Python 自然语言处理(NLP)和文本挖掘 自然语言处理(NLP)和文本挖掘是数据科学中的重要领域,涉及对文本数据的分析和处理。Python 提供了丰富的库和工具,用于执行各种 NLP 和文本挖掘任务。以下是一些常见的…

重装CentOS YUM

1. 检查是否已安装 YUM 运行以下命令检查 YUM 是否已安装: yum list installed | grep yum 如果输出中包含 yum,则说明 YUM 已安装。 2. 卸载旧版本的 YUM(如有必要) 如果需要重新安装 YUM,可以先卸载旧版本&…

Docker 镜像加速器配置指南

Docker 镜像加速器配置指南 2025-02-17 23:00 Linux : Aliyun ECS 服务器 背景问题 在国内,由于网络环境的不稳定,直接从 Docker Hub 拉取镜像的速度可能会很慢,有时甚至会失败。即使配置了官方的阿里云镜像加速器,也可能因为…