基于kd树的knn实现

K-NN(K-Nearest Neighbor) 是一种经典的分类算法,相信你已经对算法有所了解了。KNN有一种经典实现 —— KD 树实现,但是却不是太好理解,因为 KD 树涉及到了高维空间描述,大家都知道三维空间的生物理解高维空间的事物都是很困难的。

本文主要讲解 KD 树版 KNN 实现,代码主要使用 Python.

KNN 的简单理解

形象地说 KNN 利用的是 近朱者赤近墨者黑 的原理,比如要判断一个人的性格,KNN 不对这个人本身直接分析判断,而是分析这个人所在的群体进行判断(因为往往群体的信息量更大,更好地采集数据)。

KD 树的简单理解

在计算机科学里,k-d树( k-dimension tree)是在k维欧几里德空间组织点的数据结构。 — 维基百科

k-d 树是每个节点都为 k 维点的二叉树,简单理解就是原始的二叉树存储的是一个值,但现在存储的是一个在 k 维空间上的点,作为三维空间的人类为了保护脑细胞本文均以二维空间举例,也就是说本文中二叉树结点存储的是二维空间上的点,即平面直角坐标系上的点 (x, y)

k-d 树中所有非叶子节点可以视作用一个超平面把空间分割成两个半空间 (也就是左右子树分别存储着两个半空间):节点左边的子树代表在超平面左边的点,节点右边的子树代表在超平面右边的点。

选择超平面的方法如下: 每个节点都与 k 维中垂直于超平面的那一维有关。 因此,如果选择按照 x 轴划分,所有 x 值小于指定值 (为了高效的利用二叉树的空间,这个值一般是中位数) 的节点都会出现在左子树,所有 x 值大于指定值的节点都会出现在右子树。 这样,超平面可以用该x值来确定,其法线为x轴的单位向量。

为了使得构造的 kd 树尽量平衡,一般会依次切换坐标轴,也就是说:如果当前结点是按 x 轴切分,那么它的子结点是按 y 轴切分的 (再次声明,本文是建立在二维空间基础上的)

算法流程

输入: 二维空间的数据集 T = { \((x_{1}, y_{1})\), \((x_{2}, y_{2})\), …, \((x_{n}, y_{y})\) }

KD 树构造过程:

  1. 构造根结点:先选择 x 轴作为最初的坐标轴,对当前空间上的点按照 x 坐标值进行排序,选择中位数所在的 y 轴作为切分超平面,超平面左边的点由左子树存储,右边的点由右子树存储

  1. 递归:对左右子树进行和 1. 一样的递归操作,不过坐标轴要切换 (父节点选择的坐标轴是 x 轴的话,那么子结点是 y 轴,反之是 x 轴)
  2. 结束:直到两个子区域没有实例了即可停止

整个过程如下图所示

实现步骤

构建 kd 树

既然 kd 树是一棵二叉树,那么首先想到的是递归构造了,之前已经说了根结点取当前子空间在所选坐标轴上的中位数,所以我们要对当前子空间上的点的所选轴坐标的值进行排序,代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
def build_kdtree(points, depth=0):
"""
在这里我们用depth来保存当前结点的深度,
深度主要用来进行坐标轴的切换
"""

n = len(points)
if n <= 0:
# 如果当前子空间已经没有点了则构建过程结束
return None

# 计算当前选择的坐标轴
axis = depth % 2

# 对当前子空间的点根据当前选择轴的值进行排序
sorted_points = sorted(points, key=lambda point: point[axis])

# 中位数取排序后坐标点的中间位置的数
median = n // 2

return {
# 当前根结点
"point": sorted_points[median],
# 将超平面左边的点交由左子结点递归操作
"left": build_kdtree(sorted[:median], depth+1),
# 同理,将超平面右边的点交由右子结点递归操作
"right": build_kdtree(sorted_points[median+1:], depth+1)
}

找邻近点1:距离计算

因为要找邻近的点,所以需要距离来很衡量,平面直角坐标系中两点 \(p_{1}(x_{1}, y_{1}), p2(x_{2}, y_{2})\) 距离的计算公式为:

\[dist = \sqrt{(x_{2} - x_{1})^{2} + (y_{2} - y_{1})^{2}}\]

Python 实现也很简单

1
2
3
4
5
6
7
8
9
10
import math

def distance(point1, point2):
x1, y1 = point1
x2, y2 = point2

# math 的 sqrt(*) 用来计算平方根
dist = math.sqrt( (x2-x1)**2 + (y2-y1)**2 )

return dist

找邻近点2:找到两结点离当前点更近的结点

为什么要做这一步呢?因为遍历二叉树的时候我们需要将当前点分别和根结点、左子结点、右子结点的距离进行比较。该函数实现如下

1
2
3
4
5
6
7
8
9
10
def closer_distance(point, p1, p2):
if p1 is None:
return p2
if p2 is None:
return p1

d1 = distance(point, p1)
d2 = distance(point, p2)

return p1 if d1 < d2 else p2

找邻近点3:从根结点出发,遍历 kd 树,找到与给定点最近的结点

这一步是找邻近点最后一步,不做过多解释,代码看懂了也就理解了

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
def kdtree_closest_point(root, point, depth=0):
if root is None:
return None

# 确定当前轴
axis = depth % 2

next_branch = None
opposite_branch = None

if point[axis] < root['point'][axis]:
next_branch = root['left']
opposite_branch = root['right']
else:
next_branch = root['right']
opposite_branch = root['left']

# 以下主要是比较当前点到根结点和两个子结点之间的距离

best = closer_distance(
point,
kdtree_closest_point(
next_branch,
point,
depth + 1),
root['point']
)

# 当前结点和根结点在所选轴上的距离
# abs(*) : 取绝对值
cnt_dist = abs(point[axis] - root['point'][axis])

if distance(point, best) > cnt_dist:

best = closer_distance(
point,
kdtree_closest_point(
opposite_branch,
point,
depth + 1),
best
)

return best

完整代码

knn 预测时往往不是拿最邻近点的标签作为最后的预测(避免出现鹤立鸡群的情况),而是取 topk 个距离最近的点通过投票的机制来选择票数最高的标签作为最后的输出

完整的代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
# -*- coding: utf-8 -*-

import math
from collections import defaultdict

class KNN(object):
def __init__(self, topk=3):
self.data = None
self.store = {}
self.topk = topk

def build_kdtree(self, points, depth=0):
n = len(points)
if n <= 0:
# 如果当前子空间已经没有点了则构建过程结束
return None

# 计算当前选择的坐标轴
axis = depth % 2
# 对当前子空间的点根据当前选择轴的值进行排序
sorted_points = sorted(points, key=lambda point: point[axis])
# 中位数取排序后坐标点的中间位置的数
median = n // 2

return {
# 当前根结点
'point': sorted_points[median],
# 将超平面左边的点交由左子结点递归操作
'left': self.build_kdtree(sorted_points[:median], depth+1),
# 同理,将超平面右边的点交由右子结点递归操作
'right': self.build_kdtree(sorted_points[median+1:], depth+1)
}

def distance(self, p1, p2):
if p1 is None or p2 is None:
return 0

x1, y1 = p1
x2, y2 = p2

x_ = x2 - x1
y_ = y2 - y1

return math.sqrt(x_**2 + y_**2)

def closer_distance(self, point, p1, p2):

d1 = self.distance(point, p1)
d2 = self.distance(point, p2)

if p1 is None:
return (p2, d2)
if p2 is None:
return (p1, d1)

return (p1, d1) if d1 < d2 else (p2, d2)

def kdtree_closest_point(self, root, point, depth=0):
if root is None:
return None

axis = depth % 2

next_branch = None
opposite_branch = None

# 以下主要是比较当前点到根结点和两个子结点之间的距离
if point[axis] < root['point'][axis]:
next_branch = root['left']
opposite_branch = root['right']
else:
next_branch = root['right']
opposite_branch = root['left']

best, closer_dist = self.closer_distance(
point,
self.kdtree_closest_point(
next_branch,
point,
depth + 1),
root['point']
)

if self.distance(point, best) > abs(point[axis] - root['point'][axis]):

best, closer_dist = self.closer_distance(
point,
self.kdtree_closest_point(
opposite_branch,
point,
depth + 1),
best
)

# 储存距离,留作投票用
if best in self.store and self.store[best] > closer_dist:
self.store[best] = closer_dist
else:
self.store[best] = closer_dist

return best

def fit(self, X, y):
self.data = dict(zip(X, y))
self.kdtree = self.build_kdtree(X)

def predict(self, point):
# best 是最邻近的点
best = self.kdtree_closest_point(self.kdtree, point)

sorted_stores = sorted(self.store.items(), key=lambda x: x[1])[:self.topk]
counter = defaultdict(int)
for candidates, score in sorted_stores:
counter[self.data[candidates]] += 1

# 按照投票数降序排列
sorted_counter = sorted(counter.items(), key=lambda x: -x[1])
counter = list(counter.items())

if len(counter) > 1:
if counter[0][1] != counter[1][1]:
best = counter[0][1]

return self.data[best]

if __name__ == '__main__':

# 训练数据
points = [(1, 1), (1, 1.2), (0, 0), (0, 0.2), (3, 0.5), (3.3, 0.9)]
labels = ['A', 'A', 'B', 'B', 'C', 'C']

knn = KNN(topk=3)
# 开始训练
knn.fit(points, labels)
# 预测
label = knn.predict((0.9,0.9))
print(label)

代码也可查看 github / knn

Refrence

sean lee wechat
欢迎关注我的公众号!
感谢支持!