京公网安备 11010802034615号
经营许可证编号:京B2-20210330
机器学习之k-近邻(kNN)算法与Python实现
k-近邻算法(kNN,k-NearestNeighbor),是最简单的机器学习分类算法之一,其核心思想在于用距离目标最近的k个样本数据的分类来代表目标的分类(这k个样本数据和目标数据最为相似)。
一 k-近邻(kNN)算法概述
1.概念
kNN算法的核心思想是用距离最近的k个样本数据的分类来代表目标数据的分类。
其原理具体地讲,存在一个训练样本集,这个数据训练样本的数据集合中的每个样本都包含数据的特征和目标变量(即分类值),输入新的不含目标变量的数据,将该数据的特征与训练样本集中每一个样本进行比较,找到最相似的k个数据,这k个数据出席那次数最多的分类,即输入的具有特征值的数据的分类。
例如,训练样本集中包含一系列数据,这个数据包括样本空间位置(特征)和分类信息(即目标变量,属于红色三角形还是蓝色正方形),要对中心的绿色数据的分类。运用kNN算法思想,距离最近的k个样本的分类来代表测试数据的分类,那么:
当k=3时,距离最近的3个样本在实线内,具有2个红色三角和1个蓝色正方形**,因此将它归为红色三角。
当k=5时,距离最近的5个样本在虚线内,具有2个红色三角和3个蓝色正方形**,因此将它归为蓝色正方形。
2.特点
优点
(1)监督学习:可以看到,kNN算法首先需要一个训练样本集,这个集合中含有分类信息,因此它属于监督学习。
(2)通过计算距离来衡量样本之间相似度,算法简单,易于理解和实现。
(3)对异常值不敏感
缺点 (4)需要设定k值,结果会受到k值的影响,通过上面的例子可以看到,不同的k值,最后得到的分类结果不尽相同。k一般不超过20。(5)计算量大,需要计算样本集中每个样本的距离,才能得到k个最近的数据样本。 (6)训练样本集不平衡导致结果不准确问题。当样本集中主要是某个分类,该分类数量太大,导致近邻的k个样本总是该类,而不接近目标分类。
3.kNN算法流程
一般情况下,kNN有如下流程:
(1)收集数据:确定训练样本集合测试数据;
(2)计算测试数据和训练样本集中每个样本数据的距离;
常用的距离计算公式:
(3)按照距离递增的顺序排序;
(4)选取距离最近的k个点;
(5)确定这k个点中分类信息的频率;
(6)返回前k个点中出现频率最高的分类,作为当前测试数据的分类。二 、Python算法实现
1.KNN算法分类器
建立一个名为“KNN.py”的文件,构造一个kNN算法分类器的函数:
from numpy import *
import operator
#定义KNN算法分类器函数
#函数参数包括:(测试数据,训练数据,分类,k值)
def classify(inX,dataSet, labels, k):
dataSetSize = dataSet.shape[0]
diffMat = tile(inX,(dataSetSize,1))-dataSet
sqDiffMat=diffMat**2
sqDistances=sqDiffMat.sum(axis=1)
distances=sqDistances**0.5 #计算欧式距离
sortedDistIndicies=distances.argsort() #排序并返回index
#选择距离最近的k个值
classCount={}
for i in range(k):
voteIlabel=labels[sortedDistIndicies[i]]
#D.get(k[,d]) -> D[k] if k in D, else d. d defaults to None.
classCount[voteIlabel]=classCount.get(voteIlabel,0)+1
#排序
sortedClassCount=sorted(classCount.items(),key=operator.itemgetter(1),reverse=True)
return sortedClassCount[0][0]
在KNN.py中定义一个生成“训练样本集”的函数:
#定义一个生成“训练样本集”的函数,包含特征和分类信息在Python控制台先将当前目录设置为“KNN.py”所在的文件目录,将测试数据[0,0]进行KNN算法分类测试,输入:
import KNN
#生成训练样本
group,labels=KNN.createDataSet()
#对测试数据[0,0]进行KNN算法分类测试
KNN.classify([0,0],group,labels,3)
Out[3]: 'B'
可以看到该分类器函数将[0,0]分类为B组,符合实际情况,分入了符合逻辑的正确的类别。但如何知道KNN分类的正确性呢?
2.kNN算法用于约会网站配对
2.1准备数据
该数据在文本文件datingTestSet2.txt中,该数据具有1000行,4列,分别是特征数据(每年获得的飞行常客里程数,玩视频游戏所耗时间百分比,每周消费的冰淇淋公升数),和目标变量/分类数据(是否喜欢(1表示不喜欢,2表示魅力一般,3表示极具魅力)),部分数据展示如下:
完整地数据下载地址如下:
约会网站测试数据
(1)将文本记录转为成numpy
在python控制台输入:
in [5]:datingDataMat,datingLabels=KNN.file2matrix('G:\Workspaces\MachineLearning\machinelearninginaction\Ch02\datingTestSet2.txt')#括号是文件路径
(2)可视化分析数据
运用Matplotlib创建散点图来分析数据:
import matplotlib
import matplotlib.pyplot as plt
#对第二列和第三列数据进行分析:
fig=plt.figure()
ax=fig.add_subplot(111)
ax.scatter(datingDataMat[:,1],datingDataMat[:,2],c=datingLabels)
plt.xlabel('Percentage of Time Spent Playing Video Games')
plt.ylabel('Liters of Ice Cream Consumed Per Week')
#对第一列和第二列进行分析:
fig=plt.figure()
ax=fig.add_subplot(111)
ax.scatter(datingDataMat[:,0],datingDataMat[:,1],c=datingLabels)
plt.xlabel('Miles of plane Per year')
plt.ylabel('Percentage of Time Spent Playing Video Games')
ax.legend(loc='best')

(3)数据归一化
由于不同的数据在大小上差别较大,在计算欧式距离,整体较大的数据明细所占的比重更高,因此需要对数据进行归一化处理。
在Python控制台输入:
reload(KNN)数据的准备工作完成,下一步对算法进行测试。
2.2 算法测试
kNN算法分类的结果的效果,可以使用正确率/错误率来衡量,错误率为0,则表示分类很完美,如果错误率为1,表示分类完全错误。我们使用1000条数据中的90%作为训练样本集,其中的10%来测试错误率。
#定义测试算法的函数在控制台输入命令来测试错误率:
reload(KNN)
Out[150]: <module 'KNN' from 'G:\\Workspaces\\MachineLearning\\KNN.py'>
KNN.datingClassTest()
the classifier came back with: 3,the real answer is: 3
the classifier came back with: 2,the real answer is: 2
the classifier came back with: 1,the real answer is: 1
... ...
the classifier came back with: 2,the real answer is: 2
the classifier came back with: 1,the real answer is: 1
the classifier came back with: 3,the real answer is: 1
the total error rate is : 0.050000
可以看到KNN算法分类器处理约会数据的错误率是5%,具有较高额正确率。
可以在datingClassTest函数中传入参数h来改变测试数据比例,来看修改后Ration后错误率有什么样的变化。
KNN.datingClassTest(0.2)
the classifier came back with: 3,the real answer is: 3
the classifier came back with: 2,the real answer is: 2
the classifier came back with: 1,the real answer is: 1
... ...
the classifier came back with: 2,the real answer is: 2
the classifier came back with: 3,the real answer is: 3
the classifier came back with: 2,the real answer is: 2
the total error rate is : 0.080000
减小训练样本集数据,增加测试数据,错误率增加到8%。
2.3 使用KNN算法进行预测
def classifypersion():测试一下:
reload(KNN)
Out[153]: <module 'KNN' from 'G:\\Workspaces\\MachineLearning\\KNN.py'>
KNN.classifypersion()
percentage of time spent playing video games?10
frequent flier miles earned per year?10000
liters of ice creamconsued per year?0.5
You will probably like this persion :not at all
3. KNN算法用于手写识别系统
已经将图片转化为32*32 的文本格式,文本格式如下:
00000000000111110000000000000000
00000000001111111000000000000000
00000000011111111100000000000000
00000000111111111110000000000000
00000001111111111111000000000000
00000011111110111111100000000000
00000011111100011111110000000000
00000011111100001111110000000000
00000111111100000111111000000000
00000111111100000011111000000000
00000011111100000001111110000000
00000111111100000000111111000000
00000111111000000000011111000000
00000111111000000000011111100000
00000111111000000000011111100000
00000111111000000000001111100000
00000111111000000000001111100000
00000111111000000000001111100000
00000111111000000000001111100000
00000111111000000000001111100000
00000011111000000000001111100000
00000011111100000000011111100000
00000011111100000000111111000000
00000001111110000000111111100000
00000000111110000001111111000000
00000000111110000011111110000000
00000000111111000111111100000000
00000000111111111111111000000000
00000000111111111111110000000000
00000000011111111111100000000000
00000000001111111111000000000000
00000000000111111110000000000000
3.1数据准备
(1)将32*32的文本格式转为成1*2014的向量
在控制台中输入命令测试下函数:
reload(KNN)
3.2 算法测试
使用kNN算法测试手写数字识别
#引入os模块的listdir函数,列出给定目录的文件名
from os impor listdir
def handwritingClassTest():
hwLabels=[]
trainingFileList=listdir('G:/Workspaces/MachineLearning/machinelearninginaction/Ch02/trainingDigits')#列出文件名
m=len(trainingFileList) #文件数目
trainMat=zeros((m,1024))
#从文件名中解析分类信息,如0_13.txt
for i in range(m):
fileNameStr=trainingFileList[i]
fileStr=fileNameStr.split('.')[0]
classNumber=int(fileStr.split('_')[0])
hwLabels.append(classNumber)
trainMat[i]=img2vector('G:/Workspaces/MachineLearning/machinelearninginaction/Ch02/trainingDigits/%s'%fileNameStr)
testFileList=listdir('G:/Workspaces/MachineLearning/machinelearninginaction/Ch02/testDigits')
errorCount=0
#同上,解析测试数据的分类信息
mTest=len(testFileList)
for i in range(mTest):
fileNameStr=testFileList[i]
fileStr=fileNameStr.split('.')[0]
classNumber=int(fileStr.split('_')[0])
vectorUnderTest=img2vector('G:/Workspaces/MachineLearning/machinelearninginaction/Ch02/testDigits/%s'%fileNameStr)
classifierResult=classify(vectorUnderTest,trainMat,hwLabels,3)
print('the classifier came back with :%d,the real answer is:%d'%(classifierResult,classNumber))
if(classifierResult!=classNumber):errorCount+=1
print('\n the total number of errors is: %d'%errorCount)
print('\n total error rate is %f'%(errorCount/float(mTest)))
接下来在Python控制台输入命令来测试手写数字识别:
reload(KNN)
KNN.handwritingClassTest()
the classifier came back with :0,the real answer is:0
the classifier came back with :0,the real answer is:0
the classifier came back with :0,the real answer is:0
... ...
the classifier came back with :9,the real answer is:9
the classifier came back with :9,the real answer is:9
the classifier came back with :9,the real answer is:9
the total number of errors is: 10
total error rate is 0.010571
错误利率1.057%,具有较高的准确率。
CDA学员免费下载查看报告全文:2026全球数智化人才指数报告【CDA数据科学研究院】.pdf
数据分析咨询请扫描二维码
若不方便扫码,搜微信号:CDAshujufenxi
近日,由 CDA 数据科学研究院重磅发布的《2026 全球数智化人才指数报告》,被中国教育科学研究院官方账号正式收录, ...
2026-04-22在数字化时代,客户每一次点击、浏览、下单、咨询等行为,都在传递其潜在需求与决策倾向——这些按时间顺序串联的行为轨迹,构成 ...
2026-04-22数据是数据分析、建模与业务决策的核心基石,而“数据清洗”作为数据预处理的核心环节,是打通数据从“原始杂乱”到“干净可用” ...
2026-04-22 很多数据分析师每天盯着GMV、转化率、DAU等数字看,但当被问到“什么是指标”“指标和维度有什么区别”“如何搭建一套完整的 ...
2026-04-22在数据分析与业务决策中,数据并非静止不变的数值,而是始终处于动态波动之中——股市收盘价的每日涨跌、企业月度销售额的起伏、 ...
2026-04-21在数据分析领域,当研究涉及多个自变量与多个因变量之间的复杂关联时,多变量一般线性分析(Multivariate General Linear Analys ...
2026-04-21很多数据分析师精通描述性统计,能熟练计算均值、中位数、标准差,但当被问到“用500个样本如何推断10万用户的真实满意度”“这 ...
2026-04-21在数据处理与分析的全流程中,日期数据是贯穿业务场景的核心维度之一——无论是业务报表统计、用户行为追踪,还是风控规则落地、 ...
2026-04-20在机器学习建模全流程中,特征工程是连接原始数据与模型效果的关键环节,而特征重要性分析则是特征工程的“灵魂”——它不仅能帮 ...
2026-04-20很多数据分析师沉迷于复杂的机器学习算法,却忽略了数据分析最基础也最核心的能力——描述性统计。事实上,80%的商业分析问题, ...
2026-04-20在数字化时代,数据已成为企业决策的核心驱动力,数据分析与数据挖掘作为解锁数据价值的关键手段,广泛应用于互联网、金融、医疗 ...
2026-04-17在数据处理、后端开发、报表生成与自动化脚本中,将 SQL 查询结果转换为字符串是一项高频且实用的操作。无论是拼接多行数据为逗 ...
2026-04-17面对一份上万行的销售明细表,要快速回答“哪个地区卖得最好”“哪款产品增长最快”“不同客户类型的购买力如何”——这些看似复 ...
2026-04-17数据分析师一天的工作,80% 的时间围绕表格结构数据展开。从一张销售明细表到一份完整的分析报告,表格结构数据贯穿始终。但你真 ...
2026-04-16在机器学习无监督学习领域,Kmeans聚类因其原理简洁、计算高效、可扩展性强的优势,成为数据聚类任务中的主流算法,广泛应用于用 ...
2026-04-16在机器学习建模实践中,特征工程是决定模型性能的核心环节之一。面对高维数据集,冗余特征、无关特征不仅会增加模型训练成本、延 ...
2026-04-16在数字化时代,用户是产品的核心资产,用户运营的本质的是通过科学的指标监测、分析与优化,实现“拉新、促活、留存、转化、复购 ...
2026-04-15在企业数字化转型、系统架构设计、数据治理与AI落地过程中,数据模型、本体模型、业务模型是三大核心基础模型,三者相互支撑、各 ...
2026-04-15数据分析师的一天,80%的时间花在表格数据上,但80%的坑也踩在表格数据上。 如果你分不清数值型和文本型的区别,不知道数据从哪 ...
2026-04-15在人工智能与机器学习落地过程中,模型质量直接决定了应用效果的优劣——无论是分类、回归、生成式模型,还是推荐、预测类模型, ...
2026-04-14