首页
/ 100天机器学习项目实战:K最近邻(K-NN)算法详解与应用

100天机器学习项目实战:K最近邻(K-NN)算法详解与应用

2025-07-05 01:42:41作者:郁楠烈Hubert

什么是K最近邻算法

K最近邻(K-Nearest Neighbors,简称K-NN)是一种简单而强大的监督学习算法,既可以用于分类问题,也可以用于回归问题。它的核心思想是"物以类聚"——一个样本的类别或值由其周围最近的K个邻居决定。

K-NN算法工作原理

  1. 计算距离:对于测试集中的每个样本,计算它与训练集中所有样本的距离(通常使用欧氏距离)
  2. 选择邻居:选取距离最近的K个训练样本
  3. 投票决策:对于分类问题,统计K个邻居中各类别的数量,将测试样本归为数量最多的类别;对于回归问题,则取K个邻居的平均值

项目实战:社交网络广告分类

1. 数据集理解

我们使用的数据集包含社交网络用户的以下信息:

  • 用户ID
  • 性别
  • 年龄
  • 估计薪资
  • 是否购买了产品(目标变量)

2. 数据预处理

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

# 导入数据集
dataset = pd.read_csv('Social_Network_Ads.csv')
X = dataset.iloc[:, [2, 3]].values  # 选择年龄和薪资作为特征
y = dataset.iloc[:, 4].values  # 是否购买作为目标变量

# 划分训练集和测试集
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=0)

# 特征缩放
from sklearn.preprocessing import StandardScaler
sc = StandardScaler()
X_train = sc.fit_transform(X_train)
X_test = sc.transform(X_test)

3. 模型训练与预测

from sklearn.neighbors import KNeighborsClassifier

# 创建KNN分类器
classifier = KNeighborsClassifier(n_neighbors=5, metric='minkowski', p=2)
classifier.fit(X_train, y_train)

# 预测测试集结果
y_pred = classifier.predict(X_test)

4. 模型评估

使用混淆矩阵评估模型性能:

from sklearn.metrics import confusion_matrix
cm = confusion_matrix(y_test, y_pred)

K-NN算法关键参数解析

  1. n_neighbors (K值)

    • 控制考虑的邻居数量
    • 较小的K值可能导致过拟合,较大的K值可能导致欠拟合
    • 通常通过交叉验证选择最佳K值
  2. metric (距离度量)

    • 常用的有欧氏距离('euclidean')、曼哈顿距离('manhattan')、闵可夫斯基距离('minkowski')等
    • 对于文本数据,余弦相似度可能更合适
  3. weights (权重)

    • 'uniform':所有邻居权重相同
    • 'distance':距离越近的邻居权重越大

K-NN算法的优缺点

优点

  • 简单直观,易于理解和实现
  • 无需训练阶段(惰性学习)
  • 适用于多分类问题
  • 对异常值不敏感

缺点

  • 计算复杂度高,特别是大数据集
  • 需要存储全部训练数据
  • 对不相关的特征和噪声敏感
  • 需要确定合适的K值

实际应用建议

  1. 特征缩放:K-NN对特征尺度敏感,务必进行标准化或归一化
  2. 降维处理:高维数据下距离计算可能失效,考虑使用PCA等方法降维
  3. K值选择:通过交叉验证选择最优K值,通常从3-10开始尝试
  4. 距离度量:根据数据类型选择合适的距离度量方式

可视化决策边界

理解K-NN决策过程的一个好方法是可视化其决策边界:

from matplotlib.colors import ListedColormap

X_set, y_set = X_train, y_train
X1, X2 = np.meshgrid(np.arange(start=X_set[:, 0].min()-1, stop=X_set[:, 0].max()+1, step=0.01),
                     np.arange(start=X_set[:, 1].min()-1, stop=X_set[:, 1].max()+1, step=0.01))
plt.contourf(X1, X2, classifier.predict(np.array([X1.ravel(), X2.ravel()]).T).reshape(X1.shape),
             alpha=0.75, cmap=ListedColormap(('red', 'green')))
plt.xlim(X1.min(), X1.max())
plt.ylim(X2.min(), X2.max())
for i, j in enumerate(np.unique(y_set)):
    plt.scatter(X_set[y_set == j, 0], X_set[y_set == j, 1],
                c=ListedColormap(('red', 'green'))(i), label=j)
plt.title('K-NN (Training set)')
plt.xlabel('Age')
plt.ylabel('Estimated Salary')
plt.legend()
plt.show()

通过这个项目实战,我们不仅学习了K-NN算法的基本原理,还掌握了如何使用Python实现该算法解决实际问题。K-NN虽然简单,但在许多实际应用中表现优异,是机器学习工具箱中不可或缺的一部分。