京公网安备 11010802034615号
经营许可证编号:京B2-20210330
简单易学的机器学习算法—Mean Shift聚类算法
一、Mean Shift算法概述
Mean Shift算法,又称为均值漂移算法,Mean Shift的概念最早是由Fukunage在1975年提出的,在后来由Yizong Cheng对其进行扩充,主要提出了两点的改进:
定义了核函数;
增加了权重系数。
核函数的定义使得偏移值对偏移向量的贡献随之样本与被偏移点的距离的不同而不同。权重系数使得不同样本的权重不同。Mean Shift算法在聚类,图像平滑、分割以及视频跟踪等方面有广泛的应用。
二、Mean Shift算法的核心原理
2.1、核函数
在Mean Shift算法中引入核函数的目的是使得随着样本与被偏移点的距离不同,其偏移量对均值偏移向量的贡献也不同。核函数是机器学习中常用的一种方式。核函数的定义如下所示:

并且满足:
(1)、k是非负的
(2)、k是非增的
(3)、k是分段连续的
那么,函数K(x)就称为核函数。
常用的核函数有高斯核函数。高斯核函数如下所示:
其中,h称为带宽(bandwidth),不同带宽的核函数如下图所示:

上图的画图脚本如下所示:
'''
Date:201604026
@author: zhaozhiyong
'''
import matplotlib.pyplot as plt
import math
def cal_Gaussian(x, h=1):
molecule = x * x
denominator = 2 * h * h
left = 1 / (math.sqrt(2 * math.pi) * h)
return left * math.exp(-molecule / denominator)
x = []
for i in xrange(-40,40):
x.append(i * 0.5);
score_1 = []
score_2 = []
score_3 = []
score_4 = []
for i in x:
score_1.append(cal_Gaussian(i,1))
score_2.append(cal_Gaussian(i,2))
score_3.append(cal_Gaussian(i,3))
score_4.append(cal_Gaussian(i,4))
plt.plot(x, score_1, 'b--', label="h=1")
plt.plot(x, score_2, 'k--', label="h=2")
plt.plot(x, score_3, 'g--', label="h=3")
plt.plot(x, score_4, 'r--', label="h=4")
plt.legend(loc="upper right")
plt.xlabel("x")
plt.ylabel("N")
plt.show()
2.2、Mean Shift算法的核心思想
2.2.1、基本原理
对于Mean Shift算法,是一个迭代的步骤,即先算出当前点的偏移均值,将该点移动到此偏移均值,然后以此为新的起始点,继续移动,直到满足最终的条件。此过程可由下图的过程进行说明(图片来自参考文献3):
步骤1:在指定的区域内计算偏移均值(如下图的黄色的圈)

步骤2:移动该点到偏移均值点处

步骤3: 重复上述的过程(计算新的偏移均值,移动)

步骤4:满足了最终的条件,即退出

从上述过程可以看出,在Mean Shift算法中,最关键的就是计算每个点的偏移均值,然后根据新计算的偏移均值更新点的位置。
2.2.2、基本的Mean Shift向量形式
对于给定的d维空间Rd中的n个样本点
,则对于x点,其Mean Shift向量的基本形式为:
其中,Sh指的是一个半径为h的高维球区域,如上图中的蓝色的圆形区域。Sh的定义为:
这样的一种基本的Mean Shift形式存在一个问题:在Sh的区域内,每一个点对x的贡献是一样的。而实际上,这种贡献与x到每一个点之间的距离是相关的。同时,对于每一个样本,其重要程度也是不一样的。
2.2.3、改进的Mean Shift向量形式
基于以上的考虑,对基本的Mean Shift向量形式中增加核函数和样本权重,得到如下的改进的Mean Shift向量形式:
其中:
G(x)是一个单位的核函数。H是一个正定的对称d×d矩阵,称为带宽矩阵,其是一个对角阵。w(xi)⩾0是每一个样本的权重。对角阵H的形式为:

上述的Mean Shift向量可以改写成:

Mean Shift向量Mh(x)是归一化的概率密度梯度。
2.3、Mean Shift算法的解释
在Mean Shift算法中,实际上是利用了概率密度,求得概率密度的局部最优解。
2.3.1、概率密度梯度
对一个概率密度函数f(x),已知d维空间中n个采样点xi,i=1,⋯,n,f(x)的核函数估计(也称为Parzen窗估计)为:

其中
w(xi)⩾0是一个赋给采样点xi的权重
K(x)是一个核函数
概率密度函数f(x)的梯度▽f(x)的估计为
令
,则有:
其中,第二个方括号中的就是Mean Shift向量,其与概率密度梯度成正比。
2.3.2、Mean Shift向量的修正

Mh(x)=∑ni=1G(∥∥xi−xh∥∥2)w(xi)xi∑ni=1G(xi−xh)w(xi)−x
记:
,则上式变成:
Mh(x)=mh(x)+x
这与梯度上升的过程一致。
2.4、Mean Shift算法流程
Mean Shift算法的算法流程如下:
计算mh(x)
令x=mh(x)
如果∥mh(x)−x∥<ε,结束循环,否则,重复上述步骤
三、实验
3.1、实验数据
实验数据如下图所示(来自参考文献1):

画图的代码如下:
'''
Date:20160426
@author: zhaozhiyong
'''
import matplotlib.pyplot as plt
f = open("data")
x = []
y = []
for line in f.readlines():
lines = line.strip().split("\t")
if len(lines) == 2:
x.append(float(lines[0]))
y.append(float(lines[1]))
f.close()
plt.plot(x, y, 'b.', label="original data")
plt.title('Mean Shift')
plt.legend(loc="upper right")
plt.show()
3.2、实验的源码
#!/bin/python
#coding:UTF-8
'''
Date:20160426
@author: zhaozhiyong
'''
import math
import sys
import numpy as np
MIN_DISTANCE = 0.000001#mini error
def load_data(path, feature_num=2):
f = open(path)
data = []
for line in f.readlines():
lines = line.strip().split("\t")
data_tmp = []
if len(lines) != feature_num:
continue
for i in xrange(feature_num):
data_tmp.append(float(lines[i]))
data.append(data_tmp)
f.close()
return data
def gaussian_kernel(distance, bandwidth):
m = np.shape(distance)[0]
right = np.mat(np.zeros((m, 1)))
for i in xrange(m):
right[i, 0] = (-0.5 * distance[i] * distance[i].T) / (bandwidth * bandwidth)
right[i, 0] = np.exp(right[i, 0])
left = 1 / (bandwidth * math.sqrt(2 * math.pi))
gaussian_val = left * right
return gaussian_val
def shift_point(point, points, kernel_bandwidth):
points = np.mat(points)
m,n = np.shape(points)
#计算距离
point_distances = np.mat(np.zeros((m,1)))
for i in xrange(m):
point_distances[i, 0] = np.sqrt((point - points[i]) * (point - points[i]).T)
#计算高斯核
point_weights = gaussian_kernel(point_distances, kernel_bandwidth)
#计算分母
all = 0.0
for i in xrange(m):
all += point_weights[i, 0]
#均值偏移
point_shifted = point_weights.T * points / all
return point_shifted
def euclidean_dist(pointA, pointB):
#计算pointA和pointB之间的欧式距离
total = (pointA - pointB) * (pointA - pointB).T
return math.sqrt(total)
def distance_to_group(point, group):
min_distance = 10000.0
for pt in group:
dist = euclidean_dist(point, pt)
if dist < min_distance:
min_distance = dist
return min_distance
def group_points(mean_shift_points):
group_assignment = []
m,n = np.shape(mean_shift_points)
index = 0
index_dict = {}
for i in xrange(m):
item = []
for j in xrange(n):
item.append(str(("%5.2f" % mean_shift_points[i, j])))
item_1 = "_".join(item)
print item_1
if item_1 not in index_dict:
index_dict[item_1] = index
index += 1
for i in xrange(m):
item = []
for j in xrange(n):
item.append(str(("%5.2f" % mean_shift_points[i, j])))
item_1 = "_".join(item)
group_assignment.append(index_dict[item_1])
return group_assignment
def train_mean_shift(points, kenel_bandwidth=2):
#shift_points = np.array(points)
mean_shift_points = np.mat(points)
max_min_dist = 1
iter = 0
m, n = np.shape(mean_shift_points)
need_shift = [True] * m
#cal the mean shift vector
while max_min_dist > MIN_DISTANCE:
max_min_dist = 0
iter += 1
print "iter : " + str(iter)
for i in range(0, m):
#判断每一个样本点是否需要计算偏置均值
if not need_shift[i]:
continue
p_new = mean_shift_points[i]
p_new_start = p_new
p_new = shift_point(p_new, points, kenel_bandwidth)
dist = euclidean_dist(p_new, p_new_start)
if dist > max_min_dist:#record the max in all points
max_min_dist = dist
if dist < MIN_DISTANCE:#no need to move
need_shift[i] = False
mean_shift_points[i] = p_new
#计算最终的group
group = group_points(mean_shift_points)
return np.mat(points), mean_shift_points, group
if __name__ == "__main__":
#导入数据集
path = "./data"
data = load_data(path, 2)
#训练,h=2
points, shift_points, cluster = train_mean_shift(data, 2)
for i in xrange(len(cluster)):
print "%5.2f,%5.2f\t%5.2f,%5.2f\t%i" % (points[i,0], points[i, 1], shift_points[i, 0], shift_points[i, 1], cluster[i])
3.3、实验的结果
经过Mean Shift算法聚类后的数据如下所示:

'''
Date:20160426
@author: zhaozhiyong
'''
import matplotlib.pyplot as plt
f = open("data_mean")
cluster_x_0 = []
cluster_x_1 = []
cluster_x_2 = []
cluster_y_0 = []
cluster_y_1 = []
cluster_y_2 = []
center_x = []
center_y = []
center_dict = {}
for line in f.readlines():
lines = line.strip().split("\t")
if len(lines) == 3:
label = int(lines[2])
if label == 0:
data_1 = lines[0].strip().split(",")
cluster_x_0.append(float(data_1[0]))
cluster_y_0.append(float(data_1[1]))
if label not in center_dict:
center_dict[label] = 1
data_2 = lines[1].strip().split(",")
center_x.append(float(data_2[0]))
center_y.append(float(data_2[1]))
elif label == 1:
data_1 = lines[0].strip().split(",")
cluster_x_1.append(float(data_1[0]))
cluster_y_1.append(float(data_1[1]))
if label not in center_dict:
center_dict[label] = 1
data_2 = lines[1].strip().split(",")
center_x.append(float(data_2[0]))
center_y.append(float(data_2[1]))
else:
data_1 = lines[0].strip().split(",")
cluster_x_2.append(float(data_1[0]))
cluster_y_2.append(float(data_1[1]))
if label not in center_dict:
center_dict[label] = 1
data_2 = lines[1].strip().split(",")
center_x.append(float(data_2[0]))
center_y.append(float(data_2[1]))
f.close()
plt.plot(cluster_x_0, cluster_y_0, 'b.', label="cluster_0")
plt.plot(cluster_x_1, cluster_y_1, 'g.', label="cluster_1")
plt.plot(cluster_x_2, cluster_y_2, 'k.', label="cluster_2")
plt.plot(center_x, center_y, 'r+', label="mean point")
plt.title('Mean Shift 2')数据分析师培训
#plt.legend(loc="best")
plt.show()
数据分析咨询请扫描二维码
若不方便扫码,搜微信号:CDAshujufenxi
【核心关键词】贷款、报表、课程、专业、建模、缺失值、营销、互联网、银行、办公自动化、数据分析、数据预处理、特征工程、贷 ...
2026-06-05在数据库数据查询、业务报表统计、多表关联分析中,LEFT JOIN左连接是使用率最高的SQL关联查询语句。其核心特性是保留左表全部数 ...
2026-06-05 很多数据分析师能熟练地写SQL、做透视表、算描述性统计,但当被问到“如何预测用户流失概率”“如何归因销量下滑的关键因素 ...
2026-06-05任何一款产品从诞生、普及到最终退出市场,都会遵循一套固定的发展规律,这就是产品生命周期理论。在市场竞争日益激烈、产品迭代 ...
2026-06-04在Excel数据分析、办公统计、业务报表制作场景中,数据透视表是数据汇总、分类统计、快速复盘的核心工具,能够高效完成海量原始 ...
2026-06-04 很多数据分析师拿到数据就开始清洗、建模,但当被问到“这批数据属于什么类型——结构化还是非结构化?分类变量还是数值变量 ...
2026-06-04在问卷调查与社会科学数据分析中,卡方检验是最常用、最基础的非参数检验方法,广泛应用于市场调研、用户分析、行为统计、满意度 ...
2026-06-03【核心关键词】贷款、报表、课程、专业、建模、缺失值、营销、互联网、银行、办公自动化、数据分析、数据预处理、特征工程、贷 ...
2026-06-03 很多数据分析师画过趋势图、做过业绩预测,但当被问到“这个月销售额增长20%,到底是长期趋势自然增长,还是促销活动的短期 ...
2026-06-03逻辑回归是数据分析、机器学习、统计建模中应用最广泛的二分类预测模型,常用于风险判断、行为预测、归因分析等场景。在SPSS、Py ...
2026-06-02数字经济时代,市场竞争日趋同质化,用户消费需求愈发个性化、多元化,传统依托经验、粗放式、广撒网的营销模式弊端日益凸显。长 ...
2026-06-02 很多数据分析师做过按月份的销售额趋势图,画过按天的流量折线图,但当被问到“时间序列和普通数据有什么本质区别”“季节性 ...
2026-06-02在市场竞争日趋饱和、用户需求不断细分的当下,企业创业创新、产品迭代与市场拓展不再依赖经验决策,而是需要系统化、工具化的商 ...
2026-06-01【核心关键词】调度、岗位、数据库、企业、报表、培训、程序、数据分析、数据加工、业务部门、企业数据、调度工具、业务指标、 ...
2026-06-01 很多数据分析师能熟练地计算指标、搭建标签体系,但当被问到“画像到底在解决什么问题”“画像和标签是什么关系”“画像如何 ...
2026-06-01在数据统计分析、数据清洗、异常值识别与数据分布研究中,箱型图是最直观、高效、专业的可视化分析工具。相较于柱状图、折线图仅 ...
2026-05-29Tkinter是Python内置的标准GUI图形界面库,具备无需额外安装、调用简单、兼容性强、轻量化高效等优势,是Python快速开发桌面小程 ...
2026-05-29 很多分析师在设计标签时思路清晰,但真到落地环节却面临“数据在手,不知如何转化为可用标签”的困境:或因加工方式选择不当 ...
2026-05-29【核心关键词】大数据、经理、专业、金融、客户、传统、建模、数据产品、互联网金融、产品经理、数据分析、金融行业、数据模型 ...
2026-05-28 很多分析师每天和数据打交道,但当被问到“标签是什么”“标签和指标有什么区别”“标签体系如何设计”时,却常常答不上来。 ...
2026-05-28