【机器学习】快速有效理解 K-Means 算法

news/2024/11/29 7:42:49/

什么是 K-Means ?

学习 K-Means 之前,大家首先需要对聚类有一个概念.

我们都知道,机器学习可以划分为 3 类:监督学习、无监督学习、强化学习.

无监督学习指的是数据没有标签,也就是说我们只有数据的特征,但并不知道这些数据都是什么,无监督学习算法或者是模型需要从这样的数据中学习给数据按照某种规律进行分类的能力,或者是找出不同特征之间的关联性等等.

聚类(Clustering)也归属于无监督学习,通过对所有数据进行特征分析,然后把相似的对象划分到同一个集合中,这个集合我们称之为簇.

它需要给一堆没有数据标签的数据找到给定数量的簇,每个簇的周围聚拢了许多数据,所以就可以近似认为聚焦在同一个簇周围的数据就是一个类别,这样达到了分类的效果.

比如,给一张电子表格给你,里面有 10000 个成员的信息.这些信息包含了 5 个特征,分别是性别\年龄\身高\体重\颜值.

姓名性别年龄身高体重颜值
person002417760100
person10201777098
person21291777089
person30241777060
person41201707078

现在,给你的任务就是强行给这 10000 人划分为 3 组,划分的过程就叫做聚类.

而 K-means 就是一种可以完成这样任务的算法.

如何理解 K-Means?

K 是指目标的数量,也是刚刚提到簇的数量,比如你要分 3 个类别的话,K = 3.
K 的大小不固定,按照实际需求而定.

Means 是均值的意思,这说明 K-Means 和数据的均值息息相关.

在 K-means 中,每个簇由它对应的质心(centroid)表示,质心是这个簇中所有点的中心,所谓均值也就是这个意思.

值得注意的是我们这里讲的点是指代数据集中的每一条记录的,每条记录其实就是一个向量,所以质心也不是真正的一个点,它也是一个向量.

算法思想

K-means 具体算法是什么呢?

其实非常的简单和容易被人理解.

  1. 随机生成 k 个质心.
  2. 将数据集中每条记录,在这里可以当成一个点,分配到一个簇当中.分配的原则是将点对照 k 个质心,分配到离它最近的质心所在的那个簇.
  3. 因为初始的质心取值是随机的,所以结果可能不正确.所以,算法需要不停地更新来修正.修正的依据是步骤 2 分配完所有的数据后,将每个簇的质心的取值更新为整个簇的平均值.
  4. 2 ~ 3 是一个周期,整个算法需要多个周期才能完成.完成的条件是步骤 2 中,数据集中所有的点都依附在当前质心上,也可以理解为,每次对整个簇进行平均值时,得到的结果都是一样的.

k-mean flow

代码演练

依照前面的设想,假如我们手里有一份清单,清单里面是每个成员的信息,但为了演示的方便,我只保留了 2 个特征.

姓名身高颜值
frank177100
person117798
person217789

现在有 10 个成员,目标把他们分成 3 类,那么怎么编写代码?

我们用 Python 实现它.

生成 K 个随机质心

虽然说是随机,但我们应该都能理解这个随机也应该是有约束条件的.

比如下面的图,方块代表数据集合中的数据,圆形代表生成的质点,显然生成的质点就不合理.
centroid1

这里抛出一个问题,大家思考一下为什么这样不合理?

正确的质点应该在所有的数据集合中随机生成.

centroid2

我们可以让它在数据值中最小值和最大值中生成随机数.

def rand_centroid(dataset,k):n = dataset.shape[1]centroid = np.zeros((k,n))for i in range(n):min = np.min(dataset[:,i])max = np.max(dataset[:,i])stride = max - minprint(" min ",min," max ",max)centroid[:,i] = min + np.random.rand(k).T*stridereturn centroid

我们可以测试一下这个方法.

def test_centroid:dataset = np.array([[177,99],[169,80],[170,88],[190,86]],dtype=np.float)print(" centroid ",rand_centroid(dataset,2))

结果如下:

(' min ', 169.0, ' max ', 190.0)
(' min ', 80.0, ' max ', 99.0)
(' centroid ', array([[ 189.43140332,   87.76774545],[ 175.81451753,   89.8356134 ]]))

为数据集中每一个点分配簇

依照前面的流程图,这一步中有几个关键信息.

将数据集中的点分配到簇的过程当中,要计算它们到质心的距离.

在机器学习领域,距离的表示有许多种,这里采用欧氏距离.

把每个点分配到最近的质心,这就要求再建立一张专门的表,用来保存对应的信息.这张表分为 2 列,第一列是对应到数据集中每一行数据所分配到的簇的序号,第二列对应每个数据到质心的距离.

迭代与更新

每一轮簇的更新之后,需要将质心重新调整为整个簇的数据平均值,直到所有的点都不再需要更新.

前面有提到过,更新的终止条件是数据集中所有的点都依附在当前质心上,也可以理解为,每次对整个簇进行平均值时,得到的结果都是一样的.

def get_dist(A,B):return np.sqrt(np.sum(np.square(A - B)))def inference(dataset,k):record_n = dataset.shape[0]cluster_table = np.zeros((record_n,2))centroid = rand_centroid(dataset,k)print("centroid ",centroid)need_update = Truewhile need_update:need_update = Falsefor i in range(record_n):min_index = 0min_dist = np.inf;for j in range(k):dist = get_dist(dataset[i,:],centroid[j,:])print("dist ",dist)if dist < min_dist : min_dist = distmin_index = jprint('---------------------- index ',min_index)if cluster_table[i,0] != min_index:cluster_table[i,0] = min_indexneed_update = Truecluster_table[i,0] = min_indexcluster_table[i,1] = min_distprint("cluster_table ",cluster_table)for j in range(k):tmp = cluster_table[:,0] == jtmp = dataset[np.nonzero(tmp)[0]]# print ' cluster tmp ',tmpcentroid[j,:] = np.mean(tmp,axis=0)# need_update = Falsereturn cluster_table,centroid 

代码也许很抽象,希望大家结合前面的流程图来加深理解。

一定要搞懂 cluster_tablecentroid 这两张表的意义。

  1. cluster 每一行与 dataset 每一行对应
  2. centroid 存放了 k 个质心的特征。

最终结果如何,我们可以编写测试代码验证.

def test_kmeans():dataset = np.array([[177,99],[177,70],[180,72],[169,80],[170,88],[190,86]],dtype=np.float)# tmp = dataset[:,0] == 177# tmp = dataset > 170# print(" tmp ",np.nonzero(tmp)," ",np.nonzero(tmp)[0]," d ",dataset[np.nonzero(tmp)[0]])# print(" mean ",np.mean(dataset[np.nonzero(tmp)[0]],axis=0))print(inference(dataset,3))test_kmeans()

输出的结果如下.

(array([[ 2.        ,  9.19238816],[ 1.        ,  1.80277564],[ 1.        ,  1.80277564],[ 0.        ,  4.03112887],[ 0.        ,  4.03112887],[ 2.        ,  9.19238816]]), array([[ 169.5,   84. ],[ 178.5,   71. ],[ 183.5,   92.5]]))

当然,你可以用可视化的手段将算法过程直观表现出来,如下图:

在这里插入图片描述

上面的动画中,红色的就是质心。可以观察到,它是不停变动的。

其他 3 中颜色,代表三种类别的数据,数据会有少许变动,代表 K-Means 算法更新过程中,数据的分类发生了变化。

修改后的参考代码如下:

import matplotlib.pyplot as plt
import numpy as np
from moviepy.editor import VideoClip
from moviepy.video.io.bindings import mplfig_to_npimageglobal duration
global size
global fps
global imgdef rand_centroid(dataset,k):n = dataset.shape[1]centroid = np.zeros((k,n))for i in range(n):min = np.min(dataset[:,i])max = np.max(dataset[:,i])stride = max - minprint(" min ",min," max ",max)centroid[:,i] = min + np.random.rand(k).T*stridereturn centroiddef test_centroid():dataset = np.array([[177,99],[169,80],[170,88],[190,86]],dtype=np.float)print(" centroid ",rand_centroid(dataset,2))def get_dist(A,B):return np.sqrt(np.sum(np.square(A - B)))# return np.sqrt(np.sum(np.square(A))) - np.sqrt(np.sum(np.square(B)))def inference(dataset,k):record_n = dataset.shape[0]cluster_table = np.zeros((record_n,2))centroid = rand_centroid(dataset,k)print("centroid ",centroid)need_update = Truewhile need_update:need_update = Falsefor i in range(record_n):min_index = 0min_dist = np.inf;for j in range(k):dist = get_dist(dataset[i,:],centroid[j,:])print("dist ",dist)if dist < min_dist : min_dist = distmin_index = jprint('---------------------- index ',min_index)if cluster_table[i,0] != min_index:cluster_table[i,0] = min_indexneed_update = Truecluster_table[i,0] = min_indexcluster_table[i,1] = min_distprint("cluster_table ",cluster_table)for j in range(k):tmp = cluster_table[:,0] == jtmp = dataset[np.nonzero(tmp)[0]]# print ' cluster tmp ',tmpcentroid[j,:] = np.mean(tmp,axis=0)# need_update = Falsereturn cluster_table,centroid def make_frame(t):global durationglobal sizeglobal fpsglobal imgindex = int(t * fps) - 1if index == -1 : index = 0print(" index ",index," t ",t)return img[index]def inference_log(dataset,k):global imgimg = []record_n = dataset.shape[0]cluster_table = np.zeros((record_n,2))centroid = rand_centroid(dataset,k)print("centroid ",centroid)need_update = Trueduration = 2fig, ax = plt.subplots()color_list = ['m','y','c']marker_list = ['h','v','^']while need_update:need_update = Falsefor i in range(record_n):min_index = 0min_dist = np.inf;for j in range(k):dist = get_dist(dataset[i,:],centroid[j,:])print("dist ",dist)if dist < min_dist : min_dist = distmin_index = jprint('---------------------- index ',min_index)if cluster_table[i,0] != min_index:cluster_table[i,0] = min_indexneed_update = Truecluster_table[i,0] = min_indexcluster_table[i,1] = min_distprint("cluster_table ",cluster_table)ax.clear()for j in range(k):tmp = cluster_table[:,0] == jtmp = dataset[np.nonzero(tmp)[0]]ax.scatter(tmp[:,0],tmp[:,1],c=color_list[j],marker=marker_list[j])ax.set_ylim(60.0,100.0)# print ' cluster tmp ',tmpcentroid[j,:] = np.mean(tmp,axis=0)print(" dataset ",dataset[:,0].T)ax.scatter(centroid[:,0],centroid[:,1],c='r',lineWidth=3)img_figure = mplfig_to_npimage(fig)img.append(img_figure)# need_update = Falseplt.close()return cluster_table,centroid # a = np.array([1,2,3],dtype=np.float)
# b = np.array([4,7,9],dtype=np.float)# print(" a ",np.square(a),"sum ",np.sqrt(np.sum(np.square(a))))
# print ("dist ",get_dist(a,b))def test_kmeans():global durationglobal sizeglobal fpsglobal imgdataset = np.zeros((100,2),dtype=np.float)dataset[:,0] = np.random.rand(100)*40.0+150.0dataset[:,1] = np.random.randint(0,40,size=100)+60.0# dataset = np.array([[177,99],#             [177,70],#             [172,70],#             [173,79],#             [174,70],#             [175,70],#             [180,72],#             [169,80],#             [170,88],#             [190,86]],dtype=np.float)# tmp = dataset[:,0] == 177# tmp = dataset > 170# print(" tmp ",np.nonzero(tmp)," ",np.nonzero(tmp)[0]," d ",dataset[np.nonzero(tmp)[0]])# print(" mean ",np.mean(dataset[np.nonzero(tmp)[0]],axis=0))# print(inference(dataset,3))inference_log(dataset,3)fps = 2time_span = 1.0 / fpssize = img.__len__()duration = size * time_spanprint(" duration ",duration," time_span ",time_span," size ",size)animation = VideoClip(make_frame, duration=duration)animation.write_gif('kmeans.gif', fps=fps)img = Nonetest_kmeans()

思考

K-means 算法最关键的地方是什么?

我认为是距离,因为在机器学习领域内距离的表征有许多种,这里采用的是欧式距离.
但比如曼哈顿距离、余弦相似、马氏距离等等.

采用的距离表征不同,结果也不一样.

如何验证 K-means 结果?

K-means 运行结束后,但我们还得给这个算法进行评估,这样才能确保这个算法模型结果更精确.

因此,我们需要一些判断指标来验证 K-means 划分的每个簇是否合理.

SSE(误差的平方和)就是一种可用的判断表.

我们回到上面例子的输出结果,观察 cluster_table

array([[ 2.        ,  9.19238816],[ 1.        ,  1.80277564],[ 1.        ,  1.80277564],[ 0.        ,  4.03112887],[ 0.        ,  4.03112887]

数组的第 0 列是簇的序号,而第 1 列存储的是数据集中每个点到质心的距离,我们称之为误差,误差越小,结果就越准确.

把每个簇中所有的误差取平方,然后相加的结果就是 SSE.如果一个簇的 SSE 太大了,就说明这个簇的数据太散了,不够紧凑.

如果要加以改进的话,我们可以找出 SSE 最大的簇,然后运用 K-means 对它进行再次划分.这个时候 k 一般取值 2.

但这种操作之后,簇的个数会增加 1 个,怎么办呢?

我们可以将距离很近的两个质心所代表的簇合并.

K-means 的变种二分 K-means

由于原始的 K-means 不够理想,有人就提出了另外一种叫做二分 k-means 的算法.

思路如下:

  1. 将所有数据在开始时归为一个簇.
  2. 将这个簇一分为二.
  3. 找出 SSE 最大或者最小的簇,继续划分.直到簇的个数等于 K.

如何挑选一个簇进行划分呢?依据是划分之后总的 SSE 能够最大程度降低.

当然,也可以暴力一点,每次划分找 SSE 最大的那个簇.


http://www.ppmy.cn/news/130843.html

相关文章

K210学习记录(3)——kmodel生成与使用

0、引言 2022更新说明&#xff1a;这块芯片水太深&#xff0c;能不碰最好别碰&#xff0c;官方当时留的资料实在太少&#xff08;或者说我太菜&#xff09;。 如果要调用最新的nncase工具箱所支持的算子&#xff0c;最好采用嘉楠自家工具链VScode进行开发。不建议采用迦南官方…

机器学习(2): K-means (k均值) 聚类算法 小结

目录 1 聚类简介 2 k-means算法流程 3 利用k-means 对数据进行聚类 4 利用K-means进行图像分割 5 小结 参考资料 1 聚类简介 在无监督学习中&#xff0c;训练样本的标记信息是未知的&#xff0c;我们的目标是通过对无标记训练样本的学习来解释数据的内在性质及规律&…

S32K系列S32K144学习笔记——CAN

一用S32K144苦似海&#xff0c;道友&#xff0c;能不用&#xff0c;千万不去用。 本例程基以下如图所示接口操作&#xff0c;MCU为S32K144&#xff0c;开发平台S32DSworkspace 功能描述&#xff1a;CAN0通信 CAN0_EN–>PB15 如有错误&#xff0c;麻烦帮忙指出&#xff0c;谢…

Android 65K问题之65K来源探究

65K问题相信不少人都遇到过&#xff0c;65K即65536&#xff0c;关于这个值&#xff0c;是怎么来的&#xff1f;本文进行探究&#xff01; 1Unable to execute dex: method ID not in [0, 0xffff]: 65536PS:本文只是纯探索一下这个65K的来源&#xff0c;仅此而已。 到底是65k还是…

聚类分析(K-means算法)

1 聚类分析 1.1 相似度与距离度量1.2 聚类算法 及 划分方法 2 聚类模型评估&#xff08;优缺点&#xff09;3 K-means 在 sklearn方法4 确定K值–肘部法则–SSE5 模型评估指标–轮廓系数法–最近簇 5.1 轮廓系数5.2 最近簇定义—平均轮廓系数 [0,1]&#xff1a;5.3、Canopy算法…

K-means聚类算法

目录 简介 对距离的度量及SSE 问题及如何避免 k-means示例 k值选取 肘方法 可视化方法 TSNE 雷达图 总结一些对K-means聚类算法的理解。 K-means是一种聚类分析方法&#xff0c;聚类分析即是在没有给定划分类别的情况下&#xff0c;根据数据自身的相似度对样本数据进行…

K-Means中K值的选取

K-Means中K值的选择 &#xff08;1&#xff09;拍脑袋法&#xff08;2&#xff09;肘部法则&#xff08;Elbow Method&#xff09;&#xff08;3&#xff09;间隔统计量&#xff08;Gap Statistic&#xff09;&#xff08;4&#xff09;轮廓系数&#xff08;Silhouette Coeffic…

k近邻算法

本文来自的CSDN 博客 &#xff0c;全文地址请点击&#xff1a;https://blog.csdn.net/qq_35082030/article/details/60965320?utm_sourcecopy 0. 写在前面 在这一讲的讨论班中&#xff0c;我们将要讨论一下K近邻模型。可能有人会说&#xff0c;K近邻模型有什么好写的&#x…