-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathkdtree2.py
179 lines (123 loc) · 3.63 KB
/
kdtree2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
#------------------------MODULE TO IMPLEMENT KD-TREE------------------------------
#import math for math functions
import math
#counter function to count no. of recursive search queries
def cnt():
    """Increment the shared visit counter stored as the attribute ``cnt.count``.

    The total starts at 0 (set right after this definition) and nothing in
    this module resets it, so it keeps growing across calls.
    """
    current = cnt.count
    cnt.count = current + 1


cnt.count = 0
#function to calculate square distance between two points 'a' and 'b'
def square_distance(a, b):
    """Return the squared Euclidean distance between points *a* and *b*.

    Only the first two coordinates (indices 0 and 1) are compared; any extra
    coordinates are ignored. The result is always a float because each term
    is computed with ``math.pow``.
    """
    dx = a[0] - b[0]
    dy = a[1] - b[1]
    return math.pow(dx, 2) + math.pow(dy, 2)
#Defining a KDTree Node and its attributes
class Node:
    """One node of the KD-tree.

    Attributes:
        point: coordinates stored at this node.
        axis: index of the coordinate this node splits on.
        label: value associated with ``point``.
        left: subtree on the lower side of the split (``Node`` or ``None``).
        right: subtree on the upper side of the split (``Node`` or ``None``).
    """

    def __init__(self, pt, ax, l, lt, rt):
        self.point, self.axis, self.label = pt, ax, l
        self.left, self.right = lt, rt
class binaryheap:
    """Array-backed binary max-heap of records keyed by their item at index 2.

    The KD-tree search stores records shaped ``[point, label, squared_distance]``,
    so ``elements[0]`` is always the record with the largest distance.
    ``elements`` and ``n`` are read directly by callers, so they stay public;
    ``n`` always equals ``len(elements)``.
    """

    def __init__(self, ele=None):
        """Create a heap, optionally seeded with the records in *ele*.

        Args:
            ele: optional iterable of records; it is copied, never modified.
        """
        # The old ``ele=[]`` default was one list shared by every heap built
        # without arguments, so records leaked from one heap into the next.
        # Copying also stops a caller-supplied list from being reordered.
        self.elements = [] if ele is None else list(ele)
        self.n = len(self.elements)
        self.buildheap()

    def buildheap(self):
        """Restore the heap property bottom-up over every internal node."""
        for i in range(self.n // 2 - 1, -1, -1):
            self.heapify(i)

    def heapify(self, i):
        """Sift ``elements[i]`` down until neither child has a larger key."""
        elements = self.elements
        while True:
            left = 2 * i + 1
            right = left + 1
            largest = i
            if left < self.n and elements[largest][2] < elements[left][2]:
                largest = left
            if right < self.n and elements[largest][2] < elements[right][2]:
                largest = right
            if largest == i:
                return
            elements[i], elements[largest] = elements[largest], elements[i]
            i = largest

    def print(self):
        """Print every record on one line, each followed by `` , ``."""
        for i in range(self.n):
            print(self.elements[i], " , ", end='')
        print()

    def insert(self, x):
        """Add record *x* and sift it up to its correct position."""
        self.elements.append(x)
        self.n += 1
        i = self.n - 1
        while i > 0:
            parent = (i - 1) // 2
            if not self.elements[parent][2] < self.elements[i][2]:
                break
            self.elements[parent], self.elements[i] = self.elements[i], self.elements[parent]
            i = parent

    def extractmax(self):
        """Remove and return the record with the largest key.

        Raises:
            IndexError: if the heap is empty.
        """
        if self.n == 0:
            raise IndexError("extractmax from empty heap")
        last = self.n - 1
        # Move the maximum to the end so it can be popped in O(1).
        self.elements[0], self.elements[last] = self.elements[last], self.elements[0]
        top = self.elements.pop()
        self.n -= 1
        self.heapify(0)
        return top

    def returnmax(self):
        """Return, without removing, the record with the largest key.

        Raises:
            IndexError: if the heap is empty.
        """
        if self.n == 0:
            raise IndexError("returnmax from empty heap")
        return self.elements[0]
#Implementing KDTree
class KDTree:
    """2-D KD-tree supporting k-nearest-neighbour and fixed-radius queries.

    The tree is built once in ``__init__``; no method modifies it afterwards.
    """

    def __init__(self, objects=None):
        """Build a balanced tree by splitting on the median, alternating x and y.

        Args:
            objects: iterable of ``(point, label)`` pairs, where ``point`` is
                indexable at 0 and 1. ``None`` (the default) builds an empty tree.
        """
        def build_tree(objects, axis=0):
            if not objects:
                return None
            # Sort on the current axis so the median splits the set in half.
            objects.sort(key=lambda o: o[0][axis])
            median_idx = len(objects) // 2
            median_point, median_label = objects[median_idx]
            # Alternate the splitting axis: x (0) -> y (1) -> x ...
            next_axis = (axis + 1) % 2
            return Node(median_point, axis, median_label,
                        build_tree(objects[:median_idx], next_axis),
                        build_tree(objects[median_idx + 1:], next_axis))

        # Work on a copy so the caller's sequence is never reordered.
        self.root = build_tree(list(objects) if objects is not None else [])

    def nearest_neighbor(self, destination, t, n=0, r=None):
        """Query the tree around *destination*.

        Args:
            destination: query point, indexable at 0 and 1.
            t: query type. ``1`` finds the *n* nearest neighbours; ``2`` finds
                all points strictly within distance *r*.
            n: number of neighbours to collect when ``t == 1``.
            r: search radius when ``t == 2``. It must be a number, because
                ``r * r`` is evaluated for every visited node.

        Returns:
            For ``t == 1``, a ``binaryheap`` of ``[point, label, squared_distance]``
            records (at most *n*, farthest at the root). For ``t == 2``, a list
            of such records in visiting order. ``None`` for any other *t*.

        Each visited node increments the module-wide counter returned by
        ``returncounter``.
        """
        if t == 1:
            # Pass a fresh list explicitly instead of relying on binaryheap's
            # default argument, so no state is shared between queries.
            bestheap = binaryheap([])
            if n <= 0:
                # Nothing to collect; also avoids peeking into an empty heap.
                return bestheap
        elif t == 2:
            best = []

        def recursive_search(here):
            if here is None:
                return
            # Count every real node visit for returncounter().
            cnt()
            point, axis, label = here.point, here.axis, here.label
            here_sd = square_distance(point, destination)
            if t == 1:
                # Keep the n closest records seen so far. The heap root is the
                # farthest of them and is evicted by anything strictly closer.
                if bestheap.n < n:
                    bestheap.insert([point, label, here_sd])
                elif bestheap.returnmax()[2] > here_sd:
                    bestheap.extractmax()
                    bestheap.insert([point, label, here_sd])
            elif t == 2:
                if here_sd < r * r:
                    best.append([point, label, here_sd])
            # Descend first into the side of the split containing the query
            # point (ties go left).
            diff = destination[axis] - point[axis]
            if diff <= 0:
                close, away = here.left, here.right
            else:
                close, away = here.right, here.left
            recursive_search(close)
            # Every point on the far side is at least |diff| away along this
            # axis, so diff*diff is a lower bound on its squared distance.
            bound = diff * diff
            if t == 1:
                # Keep searching while the heap is not full; once it is full,
                # only if the far side could beat the current worst candidate.
                visit_away = bestheap.n < n or bound < bestheap.returnmax()[2]
            elif t == 2:
                # Compare magnitudes: the signed test never pruned when the
                # query lay below the split value.
                visit_away = bound < r * r
            else:
                visit_away = False
            if visit_away:
                recursive_search(away)

        recursive_search(self.root)
        if t == 1:
            return bestheap
        if t == 2:
            return best
        return None

    #Method for returning the count of recursive search queries
    def returncounter(self):
        """Return the module-wide number of visited search nodes.

        The counter is shared by every KDTree instance and is not reset here.
        """
        return cnt.count