K 近邻算法(K-Nearest Neighbor)

在机器学习中,K 近邻算法(KNN)是一种经典的监督学习方法,用于分类和回归问题。尽管它是一个简单的算法,但在许多实际应用中,它依然表现出色。尤其是在以下几个场景:

  • 文本分类:通过计算文本之间的相似度,KNN 可以用于文本分类任务,例如垃圾邮件分类等。
  • 图像识别:可以用于图像的分类,通过计算图像之间的距离,找到最相似的图像并进行分类。
  • 推荐系统:可以根据用户的历史行为或偏好,推荐与其兴趣相似的商品或内容。
  • 医疗诊断:可以用于根据病人的症状和历史病例来预测疾病类型。

1. 算法原理

KNN 是一种懒学习(Lazy Learning)算法,意味着它在训练阶段几乎没有任何学习过程。算法只是简单地存储训练数据。

在预测阶段,KNN 根据测试数据点与训练数据集中的所有数据点之间的距离来选择最近的 K 个邻居。然后,算法根据邻居的标签进行预测:

  • 分类问题:通过对 K 个邻居的类别进行投票,选择出现次数最多的类别作为预测结果。
  • 回归问题:通过对 K 个邻居的数值进行平均,得到预测结果。

KNN 算法中,最常用的距离度量方法是欧氏距离,当然,还可以使用其他距离度量方法,如曼哈顿距离、切比雪夫距离等。

K 值是算法中最重要的超参数。通常情况下,较小的 K 值会导致模型过拟合,而较大的 K 值则可能导致欠拟合。因此,选择合适的 K 值至关重要。一种常见的选择 K 值的方法是交叉验证。通过在训练集上进行多次交叉验证,选择使得模型在验证集上表现最好的 K 值。

2. 算法不足

KNN 算法的第一个优点是简单易懂。它的原理非常直观,几乎没有复杂的数学推导。因此,无论是在学习还是实际应用中,KNN 都容易被理解和实现。另外,KNN 不需要对数据的分布做任何假设。这使得 KNN 在处理一些难以用传统模型描述的复杂数据分布时,依然能够有效工作。

然而,KNN 也有一些显著的缺点。首先是计算开销大。KNN 在预测阶段需要计算所有训练数据点与测试点之间的距离,随着数据量的增加,计算量会迅速增大。这使得它在处理大规模数据时,计算成本显得非常高。

其次,存储的需求也很高。由于 KNN 需要存储整个训练数据集以便进行距离计算,它在存储资源上的开销也相对较大,特别是在面对海量数据时。

最后,KNN 对数据中的噪声非常敏感。如果训练集或测试集包含噪声或异常值,KNN 可能会将其错误地作为邻居影响预测结果,导致预测的准确性下降。因此,KNN 对噪声的敏感性使得它在处理不干净数据时表现不佳。

3. 算法优化

KNN 算法的计算复杂度随着数据的维度增加而增长。这意味着,当数据集具有较高的维度时,KNN 算法的效率会显著降低,且其效果可能会受到“维度灾难”的影响。维度灾难是指随着特征空间维度的增高,样本点之间的距离会变得越来越相似,从而使得距离度量失去有效性,导致算法性能下降。

为了解决这个问题,可以使用降维方法,比如主成分分析(PCA)。PCA 是一种常用的降维技术,它通过找到数据中的主成分,将数据投影到较低维的空间中,从而降低数据的维度并保留数据中的大部分信息。降维不仅可以减少计算负担,还能提高 KNN 在高维数据集上的性能。

为了提高 KNN 在大规模数据集上的查询效率,KD 树(K-dimensional tree)和Ball 树(Ball tree)是常用的数据结构。它们通过不同的方式对训练数据进行组织,使得在计算邻居时,不必遍历整个数据集。

  • KD 树:是一种将数据分割成 K 个维度的树形结构。每次分割时,选择一个维度进行划分,这样可以加速邻居查询,尤其在低维数据集中表现优越。
  • Ball 树:则是基于球形区域进行数据划分的树形结构,适用于高维数据集,能够在更高维的空间中进行更高效的邻居搜索。

使用这些数据结构可以显著减少 KNN 算法在大数据集上的计算时间,使得 KNN 更加高效。

采用加权投票法。与简单投票法不同,在加权投票法中,邻居的影响力与它们与测试点的距离成反比。具体来说,距离测试点越近的邻居,其投票权重越大。这样做的目的是减少远离测试点的邻居对最终结果的影响,从而提高预测的准确性。加权投票法常常能够在某些数据集上提供比简单投票法更好的效果。

4. 算法使用

from sklearn.neighbors import KNeighborsClassifier
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split


def iris_data():
    """
    鸢尾花数据信息
    """

    # 1. 加载红酒数据集
    dataset = load_iris()
    # 样本数量
    print("样本数量:", len(dataset.data))
    # 类别名称
    print("类别名称:", dataset.target_names)
    # 特征名称
    print("特征名称:", dataset.feature_names)
    # 样本类别
    print("样本类别:", dataset.target)
    # 样本数据
    print("样本数据:", dataset.data)


# knn API
def knn_api():

    # 1. 加载鸢尾花数据
    iris_dataset = load_iris()

    # 2. 创建 knn 分类对象
    knn = KNeighborsClassifier(n_neighbors=3)

    # 3. 将数据集划分为训练集和测试集[测试集占所有数据集的20%]
    X_train, X_test, y_train, y_test = \
        train_test_split(iris_dataset.data, iris_dataset.target, test_size=0.2)

    # 4. 训练数据
    knn.fit(X_train, y_train)

    # 5. 进行预测
    for index in range(len(X_test)):

        # 5.1 计算当前测试样本属于哪个分类
        class_label = knn.predict([X_test[index]])
        # 5.2 计算当前测试样本属于每个分类的概率
        class_proba = knn.predict_proba([X_test[index]])
        # 5.3 返回距离当前测试样本最近的 N 个样本
        nearest_neighbors = knn.kneighbors([X_test[index]], 5, False)

    # 预测准确率
    predict_score = knn.score(X_test, y_test, sample_weight=None)
    print("准确率:", predict_score)


if __name__ == "__main__":
    knn_api()
未经允许不得转载:一亩三分地 » K 近邻算法(K-Nearest Neighbor)
评论 (0)

9 + 1 =