写点什么

Python 实现 KNN 算法

作者:TiAmo
  • 2023-05-31
    江苏
  • 本文字数:3379 字

    阅读完需:约 11 分钟

Python实现KNN算法

本篇我们将讨论一种广泛使用的分类技术,称为 k 邻近算法,或者说 K 最近邻(KNN,k-Nearest Neighbor)。所谓 K 最近邻,是 k 个最近的邻居的意思,即每个样本都可以用它最接近的 k 个邻居来代表。

01、KNN 算法思想

如果一个样本在特征空间中的 k 个最相邻的样本中的大多数属于某一个类别,则该样本也属于这个类别,并具有这个类别上样本的特性。该方法在确定分类决策上只依据最邻近的一个或者几个样本的类别来决定待分样本所属的类别。KNN 方法在类别决策时,只与极少量的相邻样本有关。


由于 KNN 方法主要靠周围有限的邻近的样本,而不是靠判别类域的方法来确定所属类别的,因此对于类域的交叉或重叠较多的待分样本集来说,KNN 方法较其他方法更为适合。

02、KNN 算法的决策过程

下图中有两种类型的样本数据,一类是蓝色的正方形,另一类是红色的三角形,中间那个绿色的圆形是待分类数据:


▍近邻分类图

如果 K=3,那么离绿色点最近的有 2 个红色的三角形和 1 个蓝色的正方形,这三个点进行投票,于是绿色的待分类点就属于红色的三角形。而如果 K=5,那么离绿色点最近的有 2 个红色的三角形和 3 个蓝色的正方形,这五个点进行投票,于是绿色的待分类点就属于蓝色的正方形。 


KNN 算法不仅可以用于分类,还可以用于回归。通过找出一个样本的 k 个最近邻居,将这些邻居的属性的平均值赋给该样本,就可以得到该样本的属性。更有用的方法是将不同距离的邻居对该样本产生的影响给予不同的权值(weight),如权值与距离成反比。


下面用代码来实现 KNN 算法的应用。本次用到的数据是经典的 Iris 数据集。该数据集有 150 条鸢尾花数据样本,并且均匀分布在 3 个不同的亚种:每个数据样本被 4 个不同的花瓣、花萼的形状特征所描述。

#读取数据from sklearn.datasets import load_irisdata = load_iris()#查看数据大小data.data.shape(150, 4)#查看数据说明print (data.DESCR)Notes-----Data Set Characteristics:    :Number of Instances: 150 (50 in each of three classes)    :Number of Attributes: 4 numeric, predictive attributes and the class    :Attribute Information:        - sepal length in cm        - sepal width in cm        - petal length in cm        - petal width in cm        - class:                - Iris-Setosa                - Iris-Versicolour                - Iris-Virginica    :Summary Statistics:
============== ==== ==== ======= ===== ==================== Min Max Mean SD Class Correlation ============== ==== ==== ======= ===== ==================== sepal length: 4.3 7.9 5.84 0.83 0.7826 sepal width: 2.0 4.4 3.05 0.43 -0.4194 petal length: 1.0 6.9 3.76 1.76 0.9490 (high!) petal width: 0.1 2.5 1.20 0.76 0.9565 (high!) ============== ==== ==== ======= ===== ====================
:Missing Attribute Values: None :Class Distribution: 33.3% for each of 3 classes. :Creator: R.A. Fisher :Donor: Michael Marshall (MARSHALL%PLU@io.arc.nasa.gov) :Date: July, 1988This is a copy of UCI ML iris datasets.http://archive.ics.uci.edu/ml/datasets/IrisThe famous Iris database, first used by Sir R.A FisherThis is perhaps the best known database to be found in the pattern recognition literature. Fisher's paper is a classic in the field and is referenced frequently to this day. (See Duda & Hart, for example.) The data set contains 3 classes of 50 instances each, where each class refers to a type of iris plant. One class is linearly separable from the other 2; the latter are NOT linearly separable from each other.
References---------- - Fisher,R.A. "The use of multiple measurements in taxonomic problems" Annual Eugenics, 7, Part II, 179-188 (1936); also in "Contributions to Mathematical Statistics" (John Wiley, NY, 1950). - Duda,R.O., & Hart,P.E. (1973) Pattern Classification and Scene Analysis. (Q327.D83) John Wiley & Sons. ISBN 0-471-22361-1. See page 218. - Dasarathy, B.V. (1980) "Nosing Around the Neighborhood: A New System Structure and Classification Rule for Recognition in Partially Exposed Environments". IEEE Transactions on Pattern Analysis and Machine Intelligence, Vol. PAMI-2, No. 1, 67-71. - Gates, G.W. (1972) "The Reduced Nearest Neighbor Rule". IEEE Transactions on Information Theory, May 1972, 431-433. - See also: 1988 MLC Proceedings, 54-64. Cheeseman et al"s AUTOCLASS II conceptual clustering system finds 3 classes in the data. - Many, many more ...
复制代码

通过上述代码对数据的查验以及数据本身的描述,我们可以了解到 Iris 数据集共有 150 条鸢尾花数据样本,并且均匀分布在 3 个不同的亚种;每一个数据样本被 4 个不同的花瓣、花萼的形状特征所描述。由于没有指定的测试集,依据管理,我们需要第数据进行随机分割,25%的数据用作测试,75 的数据用作训练。


需要强调的是,如果读者朋友自行编写程序用作数据分割,请务必保证是随机采样。尽管很多数据集中的样本的排序相对随机,但是也有例外。本例中,Iris 数据就是根据类别一次排列的。如果只采样前 25%的数据用作测试,那么所有的测试样本都属于一个类别,同时训练样本也是不均衡的,这样得到的结果存在偏置,并且可信度非常低,Scikit-learn 所提供的数据分割模块是默认采用随机采样的功能的,因此大家可不必担心。

#对数据进行分割from sklearn.cross_validation import train_test_splitX_train, X_test, y_train, y_test = train_test_split(data.data, data.target, test_size = 0.25, random_state = 33)
#使用KNN算法进行分类from sklearn.preprocessing import StandardScalerfrom sklearn.neighbors import KNeighborsClassifier#初始化ss = StandardScaler()
#数据标准化X_train = ss.fit_transform(X_train)X_test = ss.transform(X_test)
#训练模型knc = KNeighborsClassifier()knc.fit(X_train, y_train)#预测y_pred = knc.predict(X_test)
#模型评估print ('The accuracy of KNN is:', knc.score(X_test, y_test))from sklearn.metrics import classification_reportprint(classification_report(y_test, y_pred, target_names = data.target_names))
复制代码

代码输出结果如下,Knn 算法对鸢尾花测试数据的分类准确率为 89.474%,其他数据如下可见。


KNN 算法的特点分析:KNN 算法是非常直观的机器学习模型,因此深受广大初学者的喜爱。许多教科书往往一次模型抛砖引玉,便足以看出其不仅特别,而且尚有瑕疵之处。细心的读者会发现,KNN 算法与其他算法模型最大的不同在于:该模型没有参数训练过程。也就是说,我们并没有通过任何学习算法来分析训练数据,而只是根据测试样本在训练数据中的的分布直接做出分类决策。因此,KNN 算法属于无参数模型中非常简单的一种。然而,正是这样的决策算法,导致了其非常高的计算复杂度和内存消耗。因为该模型每处理一个测试样本,都需要对所有事先加载在内存中的训练样本进行遍历、逐一计算相似度、排序并且选取 K 个最近邻训练样本的标记,进而做出分类决策。这是平方级的算法复杂度,一旦数据规模稍大,便需要权衡更多计算时间的代价。


最后,对 KNN 算法做一个简单的小结:

优点

简单,易于理解,易于实现,无需估计参数,无需训练;

适合对稀有事件进行分类;

特别适合于多分类问题(multi-modal,对象具有多个类别标签),kNN 比 SVM 的表现要好。

缺点

当样本不平衡时,如一个类的样本容量很大,而其他类样本容量很小时,有可能导致当输入一个新样本时,该样本的 K 个邻居中大容量类的样本占多数,少数类容易分错。

需要存储全部训练样本。

计算量较大,因为对每一个待分类的文本都要计算它到全体已知样本的距离,才能求得它的 K 个最近邻点。

可理解性差,无法给出像决策树那样的规则。

发布于: 2023-05-31阅读数: 22
用户头像

TiAmo

关注

有能力爱自己,有余力爱别人! 2022-06-16 加入

CSDN全栈领域优质创作者,万粉博主;阿里云专家博主、星级博主、技术博主、阿里云问答官,阿里云MVP;华为云享专家;华为Iot专家;亚马逊人工智能自动驾驶(大众组)吉尼斯世界纪录获得者

评论

发布
暂无评论
Python实现KNN算法_算法_TiAmo_InfoQ写作社区