通俗易懂讲透均值漂移(Mean Shift)聚类算法
摘要: Mean Shift是一种无需预设簇数的密度聚类算法,通过让数据点向高密度区域漂移实现自动分组。其核心思想是设置带宽参数(h),计算每个点邻域内的加权均值并迭代移动,直到收敛形成聚类。算法优点包括适应任意形状簇、抗噪声,但计算复杂度高且对带宽敏感。适用于图像分割、目标跟踪等场景,尤其适合不规则分布的小规模数据。文中通过糖果分组案例、公式推导和Python代码(含自动带宽估算与可视化)详细讲
通俗易懂讲透均值漂移(Mean Shift)聚类算法
不用指定簇数、自动找高密度区域,这是Mean Shift最香的特点!本文用大白话+生活案例+公式详解+可直接运行代码,本科生、研究生都能轻松看懂。
一、均值漂移是什么?一句话讲明白
均值漂移(Mean Shift)是基于密度的无监督聚类算法。
它的核心逻辑:
让每个数据点往“周围点最多、最拥挤”的地方移动,直到不再动,最后聚在一起的就是一类。
你可以把它理解为:
- 数据点 = 水面上的小船
- 每只船都往人最多的湖心漂
- 最后所有船自动聚成几堆 → 这就是聚类结果
二、超通俗生活案例:给糖果自动分组
桌上有6颗糖果,坐标如下:(1,2)、(2,2)、(3,2)、(8,8)、(9,8)、(9,9)
我们用Mean Shift自动分组:
步骤1:设置窗口半径 bandwith
相当于“看多远”,这里设 h=3。
步骤2:每个点找周围邻居,算平均位置
- 左边3个点邻居互相重叠,平均位置是 (2,2)
- 右边3个点邻居互相重叠,平均位置约 (8.67,8.33)
步骤3:不断往平均位置漂移
所有点慢慢向这两个中心靠拢。
步骤4:收敛停止
最终自动分成2类,完全不用提前告诉算法分几组!
三、Mean Shift 标准算法流程(必背)
- 选带宽h:确定每个点观察的范围大小
- 对每个点:在窗口内算加权均值(密度中心)
- 漂移:把点移到这个均值位置
- 迭代:重复计算→移动,直到点几乎不动
- 合并中心:靠得很近的中心算同一个簇
四、核心公式(报告/作业直接用)
1. 核密度估计(算密度)
f(x)=1nhd∑i=1nK(∥x−xi∥h) f(x)=\frac{1}{n h^{d}} \sum_{i=1}^{n} K\left(\frac{\left\| x-x_{i}\right\| }{h}\right) f(x)=nhd1i=1∑nK(h∥x−xi∥)
- h:带宽(搜索半径)
- K:核函数(常用高斯核)
- 含义:x点周围有多“拥挤”
2. 加权均值(漂移目标点)
μ(x)=∑xi⋅K(∥x−xi∥h)∑K(∥x−xi∥h) \mu(x)=\frac{\sum x_{i} \cdot K\left(\frac{\left\|x-x_{i}\right\|}{h}\right)}{\sum K\left(\frac{\left\|x-x_{i}\right\|}{h}\right)} μ(x)=∑K(h∥x−xi∥)∑xi⋅K(h∥x−xi∥)
3. 均值漂移向量(移动方向)
m(x)=μ(x)−x m(x) = \mu(x) - x m(x)=μ(x)−x
4. 点更新规则
xnew=xold+m(x) x_{new} = x_{old} + m(x) xnew=xold+m(x)
五、关键参数:带宽 bandwidth
带宽h是Mean Shift唯一重要参数:
- h太大:簇变少,可能把不同类合并
- h太小:簇变多,过度分割
- 实战技巧:用
estimate_bandwidth()自动估算
六、完整实战代码(可直接复制运行)
用Mean Shift做二维数据聚类+密度图可视化:
# 安装依赖
# pip install numpy matplotlib scikit-learn seaborn
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.cluster import MeanShift, estimate_bandwidth
from sklearn.datasets import make_blobs
# 1. 生成测试数据
centers = [[2, 3], [8, 8], [1, 10], [9, 1]]
cluster_std = [1.0, 0.8, 1.5, 0.5]
X, _ = make_blobs(n_samples=3000, centers=centers,
cluster_std=cluster_std, random_state=42)
# 2. 自动计算最佳带宽
bandwidth = estimate_bandwidth(X, quantile=0.2, n_samples=500)
print(f"自动估算带宽:{bandwidth:.2f}")
# 3. 均值漂移聚类
ms = MeanShift(bandwidth=bandwidth, bin_seeding=True)
ms.fit(X)
labels = ms.labels_
cluster_centers = ms.cluster_centers_
n_clusters = len(np.unique(labels))
print(f"自动聚类数量:{n_clusters}")
# 4. 绘制聚类结果
plt.figure(figsize=(14, 7))
colors = sns.color_palette("bright", n_colors=n_clusters)
for k, col in zip(np.unique(labels), colors):
mask = labels == k
plt.scatter(X[mask, 0], X[mask, 1], c=[col], s=30, alpha=0.7)
plt.scatter(cluster_centers[:,0], cluster_centers[:,1],
c='black', s=300, marker='X', label='Cluster Center')
plt.title(f"Mean Shift 聚类结果(簇数={n_clusters})", fontsize=16)
plt.legend()
plt.grid(True)
plt.show()
# 5. 绘制密度分布图
plt.figure(figsize=(14,7))
sns.kdeplot(x=X[:,0], y=X[:,1], fill=True, cmap="viridis", alpha=0.8)
plt.scatter(cluster_centers[:,0], cluster_centers[:,1],
c='red', s=300, marker='X', label='Center')
plt.title("数据密度分布与聚类中心", fontsize=16)
plt.legend()
plt.grid(True)
plt.show()
代码亮点
- 自动估算带宽,不用手动调参
- 自动输出簇数
- 双图展示:聚类结果+密度分布
- 适配作业、报告、博客
七、Mean Shift 优缺点(面试/报告必写)
✅ 优点
- 不用指定K:算法自动找簇数
- 支持任意形状:环形、弯曲、不规则簇都能分
- 基于密度:对噪声有一定鲁棒性
- 原理直观:漂移→收敛,容易理解
❌ 缺点
- 速度慢:复杂度O(n²),大数据不友好
- 对带宽h极度敏感
- 高维数据效果差
- 可能收敛到局部最优
八、适用场景(什么时候用它?)
👉 首选Mean Shift
- 不知道要分几类
- 数据形状不规则、密度不均匀
- 图像分割、目标跟踪、视频目标检测
- 小规模数据、需要自动化分组
👉 不推荐用
- 百万级大数据 → 用Mini-Batch K-Means
- 高维稀疏数据 → 用谱聚类
- 追求极致速度 → 用K-Means
九、一句话总结
均值漂移(Mean Shift)是不用设簇数、自动找高密度区的聚类算法,靠“往人多的地方漂”实现分组,特别适合不规则分布与图像类任务。
更多推荐
所有评论(0)