学习日记_20241123_聚类方法(MeanShift)

server/2024/11/24 6:17:22/

前言

提醒:
文章内容为方便作者自己后日复习与查阅而进行的书写与发布,其中引用内容都会使用链接表明出处(如有侵权问题,请及时联系)。
其中内容多为一次书写,缺少检查与订正,如有问题或其他拓展及意见建议,欢迎评论区讨论交流。

文章目录

  • 前言
  • 聚类算法
    • 经典应用场景
    • Mean Shift
      • 优点:
      • 缺点:
      • 总结:
      • 简单实例(函数库实现)
      • 数学表达
      • 总结
      • 手动实现
        • 代码分析


聚类算法

聚类算法在各种领域中有广泛的应用,主要用于发现数据中的自然分组和模式。以下是一些常见的应用场景以及每种算法的优缺点:

经典应用场景

  1. 市场细分:根据消费者的行为和特征,将他们分成不同的群体,以便进行有针对性的营销。

  2. 图像分割: 将图像划分为多个区域或对象,以便进行进一步的分析或处理。

  3. 社交网络分析:识别社交网络中的社区结构。

  4. 文档分类:自动将文档分组到不同的主题或类别中。

  5. 异常检测识别数据中的异常点或异常行为。

  6. 基因表达分析:在生物信息学中,根据基因表达模式对基因进行聚类

Mean Shift

Mean Shift 是一种基于密度的聚类方法,旨在寻找数据点的高密度区域,并将其聚集成簇。以下是 Mean Shift 的优缺点:

优点:

  1. 无参数: Mean Shift 不需要预先指定簇的数量,这使得它在处理未知数据时更具灵活性。
  2. 适应性: Mean Shift 根据数据的分布自适应地确定簇的形状和大小,能够处理不同密度的簇。
  3. 有效处理非球形簇: 它能够识别和聚类形状各异的簇,而不仅限于圆形或球形。
  4. 不受噪声影响: Mean Shift 对噪声的鲁棒性较强,这使得它在某些应用中表现良好。
  5. 可解释性: 由于该算法基于数据的密度估计,因此可以清晰地理解每个簇的形成过程。

缺点:

  1. 计算复杂度高: Mean Shift 的计算复杂度通常较高,尤其是在大数据集上,因为每次迭代都需要计算所有点之间的距离。
  2. 选择带宽参数: Mean Shift 的效果在很大程度上依赖于带宽参数的选择,选择不当可能导致不理想的聚类结果。带宽过小可能导致过拟合,而过大可能导致聚类的缺乏细节。
  3. 对高维数据敏感: 在高维空间中,密度估计变得更加困难,Mean Shift 的效果可能不如在低维空间中显著。
  4. 收敛速度: Mean Shift 的收敛速度可能在某些情况下较慢,特别是在数据点分布非常稀疏的情况下。
  5. 局部极小问题: 在某些情况下,Mean Shift 可能收敛到局部极小值,而不是全局最优聚类解决方案。

总结:

Mean Shift 是一种强大且灵活的聚类算法,尤其适合于处理具有复杂形状和不同密度的簇。然而,它的计算效率和对参数设置的敏感性可能会限制其在某些应用中的有效性。在选择使用 Mean Shift 时,需考虑数据的特性和应用场景,以决定其适用性。

简单实例(函数库实现)

import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_blobs
from sklearn.cluster import MeanShift# 生成样本数据
n_samples = 1000
X, y = make_blobs(n_samples=n_samples, centers=3, cluster_std=0.60, random_state=0)# 使用 Mean Shift 进行聚类
mean_shift = MeanShift(bandwidth=1.5)  # 选择带宽参数
mean_shift.fit(X)# 获取聚类标签和中心
labels = mean_shift.labels_
cluster_centers = mean_shift.cluster_centers_# 输出聚类结果
n_clusters = len(np.unique(labels))
print(f"聚类数量: {n_clusters}")# 绘制聚类结果
plt.figure(figsize=(10, 6))
plt.scatter(X[:, 0], X[:, 1], c=labels, s=30, cmap='viridis', marker='o', edgecolor='k')
plt.scatter(cluster_centers[:, 0], cluster_centers[:, 1], c='red', s=200, alpha=0.75, marker='X')  # 聚类中心
plt.title('Mean Shift Clustering')
plt.xlabel('Feature 1')
plt.ylabel('Feature 2')
plt.grid()
plt.show()

data数据分布与代码运行结果:
在这里插入图片描述
在这里插入图片描述

数学表达

Mean Shift 是一种基于密度的聚类算法,其核心思想是通过迭代地向数据点的密度最大区域移动来进行聚类。以下是对 Mean Shift 方法的数学解释和公式推导。

  1. 核密度估计
    Mean Shift 的主要步骤是首先进行核密度估计(Kernel Density Estimation,KDE),以得到数据点的密度分布。给定数据集 X = { x 1 , x 2 , … , x n } X = \{x_1, x_2, \ldots, x_n\} X={x1,x2,,xn},我们可以使用核函数 K K K 来估计某一点 x x x 的密度:
    f ^ ( x ) = 1 n ∑ i = 1 n K ( x − x i h ) \hat{f}(x) = \frac{1}{n} \sum_{i=1}^{n} K\left(\frac{x - x_i}{h}\right) f^(x)=n1i=1nK(hxxi)
    其中:
    • f ^ ( x ) \hat{f}(x) f^(x) 是在点 x x x 的估计密度。
    • h h h 是带宽参数,控制核的宽度。
    • K K K 是核函数,常用的有高斯核、均匀核等。
  2. Mean Shift 迭代步骤
    Mean Shift 的目标是通过移动数据点来找到密度的极大值。对每个数据点 x x x,Mean Shift 更新的公式为:
    m h ( x ) = ∑ i = 1 n x i K ( x − x i h ) ∑ i = 1 n K ( x − x i h ) m_h(x) = \frac{\sum_{i=1}^{n} x_i K\left(\frac{x - x_i}{h}\right)}{\sum_{i=1}^{n} K\left(\frac{x - x_i}{h}\right)} mh(x)=i=1nK(hxxi)i=1nxiK(hxxi)
    其中:
    • m h ( x ) m_h(x) mh(x) 是在 x x x 位置的 Mean Shift 矢量,表示对 x x x 的更新。
    • 该公式表示在当前位置 x x x 周围的所有点 x i x_i xi 的加权平均。
  3. 更新规则
    Mean Shift 的更新过程可以概括如下:
    1. 计算当前点 x x x 的密度加权平均位置 m h ( x ) m_h(x) mh(x)
    2. 更新 x x x m h ( x ) m_h(x) mh(x)
    3. 重复步骤 1 和 2,直到 x x x 的变化小于某个预设的阈值(即收敛)。
  4. 收敛和聚类
    通过不断更新,每个点将最终收敛到一个高密度区域(即簇的中心)。在所有点收敛后,可以通过其聚类标签来区分不同的簇,通常使用每个点最终对应的中心位置来标识。
  5. 聚类结果
    最终的聚类结果基于每个点所收敛的中心位置。相同的聚类中心将被标记为同一簇。聚类的数量不是预先指定的,而是在算法运行后自动确定的。

总结

Mean Shift 聚类通过不断地向数据点密度最大化的方向移动,利用核密度估计来找到数据的高密度区域。其数学基础主要依赖于核函数和加权平均的概念,使其能够灵活地适应不同密度和形状的簇。选择合适的带宽参数 h h h 是影响 Mean Shift 效果的重要因素。

手动实现

import numpy as npdef gaussian_kernel(distance, bandwidth):"""高斯核函数。参数:- distance: 距离值- bandwidth: 带宽参数返回:- 核函数的值"""return (1 / (bandwidth * np.sqrt(2 * np.pi))) * np.exp(-0.5 * (distance / bandwidth) ** 2)def mean_shift(data, bandwidth, max_iter=300, tol=1e-3):"""Mean Shift 聚类算法实现。参数:- data: 输入数据点,形状为 (n_samples, n_features)- bandwidth: 带宽参数,控制核的范围- max_iter: 最大迭代次数- tol: 收敛判定的容忍度返回:- centers: 聚类中心- labels: 数据点的标签"""# 初始化:所有点都是初始聚类中心centers = np.copy(data)for it in range(max_iter):new_centers = []for center in centers:# 计算到其他点的距离distances = np.linalg.norm(data - center, axis=1)# 计算核函数权重weights = gaussian_kernel(distances, bandwidth)# 更新中心点位置new_center = np.sum(data * weights[:, np.newaxis], axis=0) / np.sum(weights)new_centers.append(new_center)# 检查是否收敛new_centers = np.array(new_centers)shift_distance = np.linalg.norm(new_centers - centers, axis=1).max()centers = new_centersif shift_distance < tol:break# 将靠近的中心合并unique_centers = []for center in centers:if not any(np.linalg.norm(center - unique_center) < bandwidth / 2 for unique_center in unique_centers):unique_centers.append(center)unique_centers = np.array(unique_centers)# 为每个点分配标签labels = np.zeros(len(data), dtype=int)for i, point in enumerate(data):distances = np.linalg.norm(unique_centers - point, axis=1)labels[i] = np.argmin(distances)return unique_centers, labels# 测试代码
if __name__ == "__main__":# 生成测试数据np.random.seed(0)data = np.vstack((np.random.normal(loc=0, scale=1, size=(50, 2)),np.random.normal(loc=5, scale=1, size=(50, 2)),np.random.normal(loc=10, scale=1, size=(50, 2))))# 调用 Mean Shiftbandwidth = 2  # 带宽参数centers, labels = mean_shift(data, bandwidth)# 可视化import matplotlib.pyplot as pltplt.scatter(data[:, 0], data[:, 1], c=labels, cmap='viridis', s=30)plt.scatter(centers[:, 0], centers[:, 1], c='red', marker='x', s=100, label='Centers')plt.legend()plt.title("Mean Shift Clustering")plt.show()

数据与结果为
在这里插入图片描述
在这里插入图片描述

代码分析
def gaussian_kernel(distance, bandwidth):"""高斯核函数。参数:- distance: 距离值- bandwidth: 带宽参数返回:- 核函数的值"""return (1 / (bandwidth * np.sqrt(2 * np.pi))) * np.exp(-0.5 * (distance / bandwidth) ** 2)

代码的数学表达式可以通过以下公式表示:
K ( x ) = 1 σ 2 π exp ⁡ ( − x 2 2 σ 2 ) K(x) = \frac{1}{\sigma \sqrt{2\pi}} \exp\left(-\frac{x^2}{2\sigma^2}\right) K(x)=σ2π 1exp(2σ2x2)
其中:

  • K ( x ) K(x) K(x) 是高斯核函数的值。
  • σ \sigma σ 是带宽参数 h h h(在此处带入 bandwidth)。
  • x x x 是距离值(在此处带入 distance)。
    将其转换为数学表达式,我们可以写成:
    K ( distance ) = 1 bandwidth ⋅ 2 π exp ⁡ ( − 1 2 ( distance bandwidth ) 2 ) K(\text{distance}) = \frac{1}{\text{bandwidth} \cdot \sqrt{2\pi}} \exp\left(-\frac{1}{2} \left(\frac{\text{distance}}{\text{bandwidth}}\right)^2\right) K(distance)=bandwidth2π 1exp(21(bandwidthdistance)2)
    这表达了在给定距离和带宽参数的情况下,高斯核函数的值。它描述了如何根据距离值和带宽来计算核函数的输出,形成一个用于加权的函数。
def mean_shift(data, bandwidth, max_iter=300, tol=1e-3):"""Mean Shift 聚类算法实现。参数:- data: 输入数据点,形状为 (n_samples, n_features)- bandwidth: 带宽参数,控制核的范围- max_iter: 最大迭代次数- tol: 收敛判定的容忍度返回:- centers: 聚类中心- labels: 数据点的标签"""# 初始化:所有点都是初始聚类中心centers = np.copy(data)for it in range(max_iter):new_centers = []for center in centers:# 计算到其他点的距离distances = np.linalg.norm(data - center, axis=1)# 计算核函数权重weights = gaussian_kernel(distances, bandwidth)# 更新中心点位置new_center = np.sum(data * weights[:, np.newaxis], axis=0) / np.sum(weights)new_centers.append(new_center)# 检查是否收敛new_centers = np.array(new_centers)shift_distance = np.linalg.norm(new_centers - centers, axis=1).max()centers = new_centersif shift_distance < tol:break# 将靠近的中心合并unique_centers = []for center in centers:if not any(np.linalg.norm(center - unique_center) < bandwidth / 2 for unique_center in unique_centers):unique_centers.append(center)unique_centers = np.array(unique_centers)# 为每个点分配标签labels = np.zeros(len(data), dtype=int)for i, point in enumerate(data):distances = np.linalg.norm(unique_centers - point, axis=1)labels[i] = np.argmin(distances)return unique_centers, labels
  1. centers = np.copy(data)
    给定数据集 { x 1 , x 2 , … , x n } ⊂ R d \{x_1, x_2, \ldots, x_n\} \subset \mathbb{R}^d {x1,x2,,xn}Rd
    初始聚类中心设定为: C = { c 1 , c 2 , … , c n } = { x 1 , x 2 , … , x n } C = \{c_1, c_2, \ldots, c_n\} = \{x_1, x_2, \ldots, x_n\} C={c1,c2,,cn}={x1,x2,,xn}
  2. centers = np.copy(data)
    for it in range(max_iter):
    1. 权重计算(核函数权重):
    • 对每个点 x i ∈ R d x_i \in \mathbb{R}^d xiRd,计算相对于当前中心 c j c_j cj 的距离: d ( x i , c j ) = ∥ x i − c j ∥ d(x_i, c_j) = \|x_i - c_j\| d(xi,cj)=xicj
    • 使用核函数计算权重,通常使用高斯核:
      K ( x ) = exp ⁡ ( − ∥ x ∥ 2 2 h 2 ) K(x) = \exp\left(-\frac{\|x\|^2}{2h^2}\right) K(x)=exp(2h2x2)
      其中, h h h 是带宽参数。
    1. 更新聚类中心
      对每个中心 c j c_j cj,根据核权重来更新:
      c j new = ∑ i = 1 n K ( x i − c j ) ⋅ x i ∑ i = 1 n K ( x i − c j ) c_j^{\text{new}} = \frac{\sum_{i=1}^n K(x_i - c_j) \cdot x_i}{\sum_{i=1}^n K(x_i - c_j)} cjnew=i=1nK(xicj)i=1nK(xicj)xi
    2. 收敛判定
      计算所有中心的最大移动距离:
      shift_distance = max ⁡ j ∥ c j new − c j ∥ \text{shift\_distance} = \max_j \|c_j^{\text{new}} - c_j\| shift_distance=jmaxcjnewcj
      如果 shift_distance < tol \text{shift\_distance} < \text{tol} shift_distance<tol,则算法收敛,否则继续迭代。
  3. unique_centers = []
    for center in centers:
    合并近似中心
    如果两个中心 c i c_i ci c j c_j cj 的距离小于 h 2 \frac{h}{2} 2h,则合并这两个中心。
  4. labels = np.zeros(len(data), dtype=int)
    for i, point in enumerate(data):
    分配标签
    对于每个数据点 x i x_i xi,计算到每个中心的距离,并分配到最近的中心:
    label i = arg ⁡ min ⁡ k ∥ x i − unique_center k ∥ \text{label}_i = \arg\min_k \|x_i - \text{unique\_center}_k\| labeli=argkminxiunique_centerk

http://www.ppmy.cn/server/144461.html

相关文章

【深入学习大模型之:微调 GPT 使其自动生成测试用例及自动化用例】

深入学习大模型之&#xff1a;微调 GPT 使其自动生成测试用例及自动化用例 1. 自动生成测试用例目标训练过程代码示范 2. 自动写自动化代码目标训练过程代码示范可能的输出 3. 自动生成文本小说目标训练过程代码示范输出示例 4. 总结 1. 自动生成测试用例 目标 训练一个大语言…

什么是C++中的模板特化和偏特化?

C中的模板特化和偏特化是模板编程中的重要概念&#xff0c;它们允许程序员为特定类型或条件提供更具体的实现。 模板特化 模板特化是指为特定类型提供一个明确的实现&#xff0c;从而覆盖普通模板的通用实现。这通常在模板类或函数对某个特定类型的处理方式需要不同于一般类型…

HARCT 2025 新增分论坛6:基于机器人的智能处理控制

会议名称&#xff1a;机电液一体化与先进机器人控制技术国际会议 会议简称&#xff1a;HARCT 2025 大会时间&#xff1a;2025年1月3日-6日 大会地点&#xff1a;中国桂林 主办单位&#xff1a;桂林航天工业学院、广西大学、桂林电子科技大学、桂林理工大学 协办单位&#…

【安卓脚本】Android工程中文硬编码抽取

【安卓脚本】Android工程中文硬编码抽取 Android 原生工程 中文硬编码抽取功能支持流程示意项目地址 Android 原生工程 中文硬编码抽取 安卓在进行国际化多语言功能时经常会遇到一个头疼的问题&#xff0c;就是在以往的项目中往往存在大量的中文硬编码&#xff0c;这块人工处理…

uniapp奇怪bug汇总

H5端请求api文件夹接口报错 踩坑指数&#xff1a;5星 小程序、APP之前都是用api文件夹的接口引用调用&#xff0c;在h5端启动时报错&#xff0c;研究半天&#xff0c;发现把api文件夹名字改成apis就能调用&#xff0c;就像是关键字一样无法使用。 import authApi from /api/…

算力100问☞第16问:什么是TPU?

TPU全称是Tensor Processing Unit芯片&#xff0c;中文全称是张量处理单元芯片&#xff0c;是谷歌开发的一种特殊类型的芯片&#xff0c;用于加速人工智能&#xff08;AI&#xff09;和机器学习&#xff08;ML&#xff09;工作负载。TPU主要针对张量&#xff08;tensor&#xf…

Unity开发抖音小游戏使用长音频和短音频

抖音小游戏使用长音频和短音频 介绍WebGL对Unity音频的限制优化建议Iphone静音不同策略Unity中播放长音频无法播放可以使用以下方法总结 介绍 最近好久没有更新文章了&#xff0c;最近在研究抖音小程序也在帮公司做抖音小游戏这块&#xff0c;正好之前遇到了一个比较困扰的问题…

【H2O2|全栈】JS进阶知识(八)ES6(4)

目录 前言 开篇语 准备工作 浅拷贝和深拷贝 浅拷贝 概念 常见方法 弊端 案例 深拷贝 概念 常见方法 弊端 逐层拷贝 原型 构造函数 概念 形式 成员 弊端 显式原型和隐式原型 概念 形式 constructor 概念 形式 原型链 概念 形式 结束语 前言 开篇语…