昇思训练营打卡第十八天(K近邻算法实现红酒聚类)

ops/2024/10/21 1:58:14/

K近邻(K-Nearest Neighbors,KNN)算法是一种基本的机器学习算法,它既可以用于分类任务,也可以用于回归任务。KNN算法的核心思想是,如果一个新样本在特征空间中的K个最邻近的样本大多数属于某一个类别,那么这个新样本也属于这个类别

KNN算法的基本步骤

  1. 选择距离度量:确定样本之间的距离计算方法,常用的距离度量方法有欧氏距离、曼哈顿距离等。

  2. 确定邻居的数量K:K值的选择对KNN算法的结果有重要影响,K值通常需要通过交叉验证等方法来确定。

  3. 选择训练样本:训练样本应该能够代表整个数据集的特性。

  4. 进行分类

    • 对于一个新的输入实例,计算它与训练集中每一个实例的距离。
    • 选择距离最近的K个实例。
    • 根据这K个实例的标签,通过多数投票等方式,确定新实例的类别。

KNN算法的优缺点

优点:
  • 简单易懂,易于实现。
  • 不需要训练模型,因此对于训练数据没有假设,适用于各种类型的决策边界。
  • 可以适用于多分类问题。
缺点:
  • 计算量大,因为需要计算每个测试样本与所有训练样本的距离。
  • 对噪声和离群点敏感,因为它们会影响近邻的选择。
  • 需要预先确定K值,K值的选择对结果有较大影响。

应用场景

KNN算法在现实世界中广泛应用于模式识别、文本分类、图像识别等领域,尤其是在数据分布较为稠密且特征维度不高的情况下表现良好。

from download import download# 下载红酒数据集
url = "https://ascend-professional-construction-dataset.obs.cn-north-4.myhuaweicloud.com:443/MachineLearning/wine.zip"  
path = download(url, "./", kind="zip", replace=True)
%matplotlib inline
import os
import csv
import numpy as np
import matplotlib.pyplot as pltimport mindspore as ms
from mindspore import nn, opsms.set_context(device_target="CPU")
with open('wine.data') as csv_file:data = list(csv.reader(csv_file, delimiter=','))
print(data[56:62]+data[130:133])
X = np.array([[float(x) for x in s[1:]] for s in data[:178]], np.float32)
Y = np.array([s[0] for s in data[:178]], np.int32)
attrs = ['Alcohol', 'Malic acid', 'Ash', 'Alcalinity of ash', 'Magnesium', 'Total phenols','Flavanoids', 'Nonflavanoid phenols', 'Proanthocyanins', 'Color intensity', 'Hue','OD280/OD315 of diluted wines', 'Proline']
plt.figure(figsize=(10, 8))
for i in range(0, 4):plt.subplot(2, 2, i+1)a1, a2 = 2 * i, 2 * i + 1plt.scatter(X[:59, a1], X[:59, a2], label='1')plt.scatter(X[59:130, a1], X[59:130, a2], label='2')plt.scatter(X[130:, a1], X[130:, a2], label='3')plt.xlabel(attrs[a1])plt.ylabel(attrs[a2])plt.legend()
plt.show()
train_idx = np.random.choice(178, 128, replace=False)
test_idx = np.array(list(set(range(178)) - set(train_idx)))
X_train, Y_train = X[train_idx], Y[train_idx]
X_test, Y_test = X[test_idx], Y[test_idx]
class KnnNet(nn.Cell):def __init__(self, k):super(KnnNet, self).__init__()self.k = kdef construct(self, x, X_train):#平铺输入x以匹配X_train中的样本数x_tile = ops.tile(x, (128, 1))square_diff = ops.square(x_tile - X_train)square_dist = ops.sum(square_diff, 1)dist = ops.sqrt(square_dist)#-dist表示值越大,样本就越接近values, indices = ops.topk(-dist, self.k)return indicesdef knn(knn_net, x, X_train, Y_train):x, X_train = ms.Tensor(x), ms.Tensor(X_train)indices = knn_net(x, X_train)topk_cls = [0]*len(indices.asnumpy())for idx in indices.asnumpy():topk_cls[Y_train[idx]] += 1cls = np.argmax(topk_cls)return cls
acc = 0
knn_net = KnnNet(5)
for x, y in zip(X_test, Y_test):pred = knn(knn_net, x, X_train, Y_train)acc += (pred == y)print('label: %d, prediction: %s' % (y, pred))
print('Validation accuracy is %f' % (acc/len(Y_test)))


http://www.ppmy.cn/ops/55825.html

相关文章

Vite配置环境变量以及动态更新html数据

一、设置配置文件 // .env // 公共配置文件,总是生效 VITE_BASE_API_URLhttp://localhost:3000// .env.development VITE_BASE_API_URL/api VITE_TAB_TITLEdevelopment title// .env.production VITE_BASE_API_URL/api VITE_TAB_TITLEproduction title 二、安装插…

springcloud 面试经常被问问题

Spring Cloud 是一个基于 Spring Boot 的微服务架构解决方案,包含了许多用于构建和管理微服务的工具和框架。在面试中,与 Spring Cloud 相关的问题通常会涉及其核心概念、组件、常用模式和解决方案。以下是一些在 Spring Cloud 面试中经常被问到的问题及…

字节跳动 AML 前端 一面

时长55mins 1. 自我介绍 1. 怎么接触的前端?学了多久? 1. 问项目 1. 为什么要做组件库? 1. 问到我的组件库和AntD之类的有什么区别,我说区别可能就是我的功能更少?hhhh 1. 设计一个组件的思路&#x…

从零开始学数据结构系列之第四章《 图的遍历总代码》

文章目录 概念回顾深度优先遍历(DFS)概念图的深度优先遍历深度优先遍历算法步骤 广度优先遍历(BFS)概念广度优先遍历算法步骤 程序总代码往期回顾 概念回顾 ​   图的遍历是和树的遍历类似,我们希望从图中某一顶点出…

LVS+Keepalived集群

Keepalived双机热备 Keepalived实现原理刨析 Keepalived采用VRRP热备份协议实现Linux服务器的多机热备功能 Keepalived案例分析 Keepalived可以实现多机热备,每个热备组可有多台服务器 双机热备的故障切换是由虚拟IP地址的漂移来实现,适用于各种应用…

paraview将raw数据转为vtk

ParaView 是一个强大的可视化工具,可以转换各种数据格式,包括将原始数据转换为 VTK 文件格式。以下是一个简单的 Python 脚本,使用 ParaView Python 接口来转换 raw 数据为 VTK 文件: # 导入ParaView模块 import paraview from p…

ARTS Week 36

unsetunsetAlgorithmunsetunset 本周的算法题为 1528. 重新排列字符串 给你一个字符串 s 和一个 长度相同 的整数数组 indices 。 请你重新排列字符串 s ,其中第 i 个字符需要移动到 indices[i] 指示的位置。 返回重新排列后的字符串。 img 示例 1:输入&…

VPN是什么?

VPN,全称Virtual Private Network,即“虚拟私人网络”,是一种在公共网络(如互联网)上建立加密、安全的连接通道的技术。简单来说,VPN就像是一条在公共道路上铺设的“秘密隧道”,通过这条隧道传输…