100天机器学习项目实战:K最近邻(K-NN)算法详解与应用
2025-07-05 01:42:41作者:郁楠烈Hubert
什么是K最近邻算法
K最近邻(K-Nearest Neighbors,简称K-NN)是一种简单而强大的监督学习算法,既可以用于分类问题,也可以用于回归问题。它的核心思想是"物以类聚"——一个样本的类别或值由其周围最近的K个邻居决定。
K-NN算法工作原理
- 计算距离:对于测试集中的每个样本,计算它与训练集中所有样本的距离(通常使用欧氏距离)
- 选择邻居:选取距离最近的K个训练样本
- 投票决策:对于分类问题,统计K个邻居中各类别的数量,将测试样本归为数量最多的类别;对于回归问题,则取K个邻居的平均值
项目实战:社交网络广告分类
1. 数据集理解
我们使用的数据集包含社交网络用户的以下信息:
- 用户ID
- 性别
- 年龄
- 估计薪资
- 是否购买了产品(目标变量)
2. 数据预处理
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
# 导入数据集
dataset = pd.read_csv('Social_Network_Ads.csv')
X = dataset.iloc[:, [2, 3]].values # 选择年龄和薪资作为特征
y = dataset.iloc[:, 4].values # 是否购买作为目标变量
# 划分训练集和测试集
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=0)
# 特征缩放
from sklearn.preprocessing import StandardScaler
sc = StandardScaler()
X_train = sc.fit_transform(X_train)
X_test = sc.transform(X_test)
3. 模型训练与预测
from sklearn.neighbors import KNeighborsClassifier
# 创建KNN分类器
classifier = KNeighborsClassifier(n_neighbors=5, metric='minkowski', p=2)
classifier.fit(X_train, y_train)
# 预测测试集结果
y_pred = classifier.predict(X_test)
4. 模型评估
使用混淆矩阵评估模型性能:
from sklearn.metrics import confusion_matrix
cm = confusion_matrix(y_test, y_pred)
K-NN算法关键参数解析
-
n_neighbors (K值):
- 控制考虑的邻居数量
- 较小的K值可能导致过拟合,较大的K值可能导致欠拟合
- 通常通过交叉验证选择最佳K值
-
metric (距离度量):
- 常用的有欧氏距离('euclidean')、曼哈顿距离('manhattan')、闵可夫斯基距离('minkowski')等
- 对于文本数据,余弦相似度可能更合适
-
weights (权重):
- 'uniform':所有邻居权重相同
- 'distance':距离越近的邻居权重越大
K-NN算法的优缺点
优点:
- 简单直观,易于理解和实现
- 无需训练阶段(惰性学习)
- 适用于多分类问题
- 对异常值不敏感
缺点:
- 计算复杂度高,特别是大数据集
- 需要存储全部训练数据
- 对不相关的特征和噪声敏感
- 需要确定合适的K值
实际应用建议
- 特征缩放:K-NN对特征尺度敏感,务必进行标准化或归一化
- 降维处理:高维数据下距离计算可能失效,考虑使用PCA等方法降维
- K值选择:通过交叉验证选择最优K值,通常从3-10开始尝试
- 距离度量:根据数据类型选择合适的距离度量方式
可视化决策边界
理解K-NN决策过程的一个好方法是可视化其决策边界:
from matplotlib.colors import ListedColormap
X_set, y_set = X_train, y_train
X1, X2 = np.meshgrid(np.arange(start=X_set[:, 0].min()-1, stop=X_set[:, 0].max()+1, step=0.01),
np.arange(start=X_set[:, 1].min()-1, stop=X_set[:, 1].max()+1, step=0.01))
plt.contourf(X1, X2, classifier.predict(np.array([X1.ravel(), X2.ravel()]).T).reshape(X1.shape),
alpha=0.75, cmap=ListedColormap(('red', 'green')))
plt.xlim(X1.min(), X1.max())
plt.ylim(X2.min(), X2.max())
for i, j in enumerate(np.unique(y_set)):
plt.scatter(X_set[y_set == j, 0], X_set[y_set == j, 1],
c=ListedColormap(('red', 'green'))(i), label=j)
plt.title('K-NN (Training set)')
plt.xlabel('Age')
plt.ylabel('Estimated Salary')
plt.legend()
plt.show()
通过这个项目实战,我们不仅学习了K-NN算法的基本原理,还掌握了如何使用Python实现该算法解决实际问题。K-NN虽然简单,但在许多实际应用中表现优异,是机器学习工具箱中不可或缺的一部分。