forked from bluehaotian/Statistical-Learning-Methods
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_kdtree.py
More file actions
21 lines (17 loc) · 760 Bytes
/
test_kdtree.py
File metadata and controls
21 lines (17 loc) · 760 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import knn_kdtree
import numpy as np
# Fixture: seven 2-D points -- a 3x3 integer grid with (2, 1) and (2, 3)
# left out -- so the expected tree layout can be checked by hand.
X = np.array(
    [
        [1, 1],
        [1, 2],
        [1, 3],
        [2, 2],
        [3, 1],
        [3, 2],
        [3, 3],
    ]
)
# Every sample carries the same label (0); the checks in this script only
# look at point coordinates, never at labels.
Y = np.zeros(len(X), dtype=int)
tree = knn_kdtree.KDTree(X, Y)
def points_equal(a, b):
    """Return True if *a* and *b* hold the same points, ignoring order.

    Each argument is an iterable of points, where a point is itself an
    iterable of coordinates (a list, a tuple, or a NumPy row). The
    comparison is order-insensitive but respects multiplicity: a collection
    that contains the same point twice is NOT equal to one that contains it
    once, so accidental duplicates in a tree node or a query result are
    detected instead of being silently collapsed.
    """
    # Points are converted to tuples so they compare by value and can be
    # ordered; sorting both sides makes the comparison order-independent
    # while keeping duplicates.
    return sorted(map(tuple, a)) == sorted(map(tuple, b))
# Expected layout of the tree built from X: one point per node. The values
# are consistent with a median split on x at the root and on y one level
# down -- NOTE(review): the actual split rule lives in knn_kdtree; confirm.
expected_nodes = [
    ("root", tree.root, [[2, 2]]),
    ("root.left", tree.root.left, [[1, 2]]),
    ("root.right", tree.root.right, [[3, 2]]),
    ("root.left.left", tree.root.left.left, [[1, 1]]),
    ("root.left.right", tree.root.left.right, [[1, 3]]),
    ("root.right.left", tree.root.right.left, [[3, 1]]),
    ("root.right.right", tree.root.right.right, [[3, 3]]),
]
for label, node, expected in expected_nodes:
    # `assert` is a statement, not a function: written without call-style
    # parentheses so that attaching a message cannot turn the condition
    # into an always-true tuple.
    assert points_equal(node.points, expected), (
        "{}.points: expected {}, got {!r}".format(label, expected, node.points)
    )

# The three nearest neighbours of (2, 1): (1, 1), (2, 2) and (3, 1) are each
# at distance 1, and every other point in X is strictly farther away, so the
# expected result is unambiguous.
query_point = np.array([2, 1])
# Each query hit is indexed with [0] to obtain the point itself; what the
# remaining fields hold is not visible here (presumably label/distance).
neighbours = [hit[0] for hit in tree.query(query_point, 3)]
expected_neighbours = [[1, 1], [2, 2], [3, 1]]
assert points_equal(neighbours, expected_neighbours), (
    "query({}, 3): expected {}, got {!r}".format(
        query_point.tolist(), expected_neighbours, neighbours
    )
)