一、KNN 算法是什麼
KNN(k-Nearest Neighbors) 是最近鄰類算法中最經典的一種,用於:
- 分類問題(多數投票)
- 迴歸問題(均值 / 加權均值)
核心思想一句話:
一個樣本屬於哪一類,由“離它最近的 K 個樣本”決定。
KNN 沒有訓練過程,本質是 基於距離的搜索算法。
二、算法數學定義
給定:
- 訓練集:( D = {(x_1,y_1),...,(x_n,y_n)} )
- 查詢點:( q )
- 距離函數:( dist(x, q) )
步驟:
- 計算 ( dist(x_i, q) )
- 選取距離最小的 K 個樣本
- 分類: [ y = \arg\max_c \sum_{i \in K} I(y_i=c) ]
- 迴歸: [ y = \frac{1}{K}\sum_{i \in K} y_i ]
三、距離度量選擇(決定效果)
常用距離
| 距離 | 公式 | 場景 | ||||
|---|---|---|---|---|---|---|
| 歐氏距離 | ( \sqrt{\sum (x_i-y_i)^2} ) | 連續特徵 | ||||
| 曼哈頓距離 | ( \sum | x_i-y_i | ) | 稀疏特徵 | ||
| 餘弦距離 | ( 1-\frac{x·y}{ | x | y | } ) | 向量相似度 |
⚠️ 特徵必須先歸一化,否則距離無意義。
四、KNN 樸素實現
1️⃣ 算法流程
- 遍歷訓練集
- 計算距離
- 排序或維護 Top-K
- 聚合標籤
2️⃣ Python 實現(分類)
import math
from collections import Counter
def euclidean_distance(a, b):
return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))
def knn_classify(X_train, y_train, x_query, k=3):
distances = []
for x, y in zip(X_train, y_train):
d = euclidean_distance(x, x_query)
distances.append((d, y))
distances.sort(key=lambda x: x[0])
top_k = distances[:k]
labels = [label for _, label in top_k]
return Counter(labels).most_common(1)[0][0]
時間複雜度:
- 單次預測:
O(n · d)
五、KNN 迴歸實現
def knn_regression(X_train, y_train, x_query, k=3):
distances = []
for x, y in zip(X_train, y_train):
d = euclidean_distance(x, x_query)
distances.append((d, y))
distances.sort(key=lambda x: x[0])
top_k = distances[:k]
return sum(y for _, y in top_k) / k
六、K 的選擇與加權 KNN
1️⃣ K 的影響
-
K 小:
- 方差大
- 對噪聲敏感
-
K 大:
- 偏差大
- 邊界模糊
經驗:
K ≈ sqrt(n)
2️⃣ 加權 KNN
思想:
距離越近,權重越大
def weighted_knn(X_train, y_train, x_query, k=3):
distances = []
for x, y in zip(X_train, y_train):
d = euclidean_distance(x, x_query)
distances.append((d, y))
distances.sort(key=lambda x: x[0])
top_k = distances[:k]
weight_sum = {}
for d, y in top_k:
w = 1 / (d + 1e-8)
weight_sum[y] = weight_sum.get(y, 0) + w
return max(weight_sum.items(), key=lambda x: x[1])[0]
七、性能瓶頸與優化方向
1️⃣ 樸素 KNN 的問題
- 無法大規模使用
- 每次查詢都全量掃描
2️⃣ 常見優化方式
| 方法 | 説明 |
|---|---|
| KD-Tree | 低維精確搜索 |
| Ball Tree | 非歐氏空間 |
| ANN | 近似最近鄰(工業級) |
實際工程中,KNN ≈ 向量檢索問題。
八、sklearn 中的 KNN
from sklearn.neighbors import KNeighborsClassifier
model = KNeighborsClassifier(
n_neighbors=5,
weights='distance',
metric='euclidean'
)
model.fit(X_train, y_train)
pred = model.predict(X_test)
內部已自動使用 KD-Tree / Ball Tree。
九、KNN 的適用與不適用場景
適合
- 小數據集
- 特徵空間有明確距離意義
- 冷啓動問題
不適合
- 高維稠密數據
- 百萬級以上樣本
- 實時高併發
十、工程實踐建議
-
不要裸用 KNN 上生產
-
小數據:KNN + KD-Tree
-
大規模:HNSW / FAISS
-
真正效果來自:
- 特徵工程
- 距離設計
十一、總結
KNN 算法本質是:
距離定義 + 搜索策略 + 工程取捨