-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathtrajectory_tree.py
75 lines (58 loc) · 2.65 KB
/
trajectory_tree.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
# Copyright 2024 Bunting Labs, Inc.
import heapq
from collections import defaultdict
from functools import lru_cache
import numpy as np
from qgis.core import QgsPointXY
class TrajectoryTree:
    """Directed weighted graph over pixel nodes, with shortest-path queries from a fixed root.

    Nodes are integer flat indices into an image of shape
    ``(img_params[0], img_params[1])``. Edges come from ``pts_costs``, whose
    keys have the form ``"<orig>_<dest>"`` and whose values are edge costs.
    """

    # Maximum number of distinct ``end`` targets memoized per instance by dijkstra().
    _DIJKSTRA_CACHE_SIZE = 100

    def __init__(self, pts_costs, params, img_params, trajectory_root):
        """Build adjacency lists from ``pts_costs``.

        Args:
            pts_costs: Mapping of ``"<orig>_<dest>"`` strings to edge costs.
                Entries whose cost is ``None`` are skipped, and so are
                self-loops (``orig == dest``).
            params: ``(x_min, y_min, dxdy, y_max)``, unpacked in this order by
                ``closest_nodes_to``.
            img_params: ``(img_height, img_width)``. The first two items are
                the shape used to un-flatten node indices.
            trajectory_root: Node from which every ``dijkstra`` search starts.
        """
        self.params = params  # (x_min, y_min, dxdy, y_max)
        self.img_params = img_params  # (img_height, img_width)
        self.trajectory_root = trajectory_root
        # Hidden, with a setter
        # NOTE(review): if graph_neighbors is replaced or mutated after queries
        # have run, the memoized results below go stale — confirm callers don't.
        self.graph_neighbors = defaultdict(list)
        for edge_key, cost in pts_costs.items():
            # NaN, inf, -inf all serialize to null in JSON
            if cost is None:
                continue
            orig, dest = map(int, edge_key.split('_'))
            if orig == dest:
                continue
            self.graph_neighbors[orig].append((dest, cost))

        # These caches live on the instance. functools.lru_cache on a method
        # keys on `self` in a class-wide cache, which keeps instances alive.
        self._nodes_coords_cache = None
        self._dijkstra_cache = {}

    def _graph_nodes_coords(self):
        """Return ``[((row, col), node), ...]`` for every node that appears in an edge.

        The list is computed once per instance. ``(row, col)`` comes from
        ``np.unravel_index`` over the shape ``(img_params[0], img_params[1])``.
        """
        if self._nodes_coords_cache is None:
            nodes = set(self.graph_neighbors)
            # Flatten the destinations in a single pass; sum(lists, []) is quadratic.
            nodes.update(
                dest for edges in self.graph_neighbors.values() for dest, _ in edges
            )
            shape = (self.img_params[0], self.img_params[1])
            self._nodes_coords_cache = [
                (np.unravel_index(int(node), shape), node) for node in nodes
            ]
        return self._nodes_coords_cache

    # TODO use a kd-tree
    def closest_nodes_to(self, pt: "QgsPointXY", n: int):
        """Return up to ``n`` graph nodes nearest to ``pt``, closest first.

        ``pt`` is mapped into image pixel coordinates using ``self.params``.
        Distances are squared Euclidean distances in pixels, and ties are
        broken by node id.
        """
        x_min, _y_min, dxdy, y_max = self.params
        # NOTE(review): x_min and y_max are scaled by 256 * dxdy, which suggests
        # they are indices of 256-pixel tiles — confirm against the caller.
        img_x = (pt.x() - x_min * 256 * dxdy) / dxdy
        # Larger map y gives a smaller image row.
        img_y = (y_max * 256 * dxdy - pt.y()) / dxdy
        dists = [
            ((img_x - col) ** 2 + (img_y - row) ** 2, node)
            for (row, col), node in self._graph_nodes_coords()
        ]
        return [node for _, node in sorted(dists)[:n]]

    def dijkstra(self, end: int):
        """Return the cheapest path from ``self.trajectory_root`` to ``end``.

        Results are memoized per instance for up to ``_DIJKSTRA_CACHE_SIZE``
        targets, evicting the least recently used first. Repeated calls with
        the same ``end`` return the same tuple object.

        Returns:
            ``(path, cost)``. ``path`` lists the nodes from the root to ``end``
            inclusive. If ``end`` is unreachable, the result is ``([], inf)``.
        """
        cache = self._dijkstra_cache
        if end in cache:
            result = cache.pop(end)
        else:
            result = self._shortest_path(end)
        cache[end] = result  # (re)insert as the most recently used entry
        if len(cache) > self._DIJKSTRA_CACHE_SIZE:
            del cache[next(iter(cache))]  # dicts keep insertion order: evict the oldest
        return result

    def _shortest_path(self, end):
        """Run Dijkstra from the root, stopping as soon as ``end`` is popped.

        Assumes non-negative edge costs, which is Dijkstra's precondition.
        """
        start = self.trajectory_root
        queue = [(0, start)]
        distances = {start: 0}
        previous_nodes = {start: None}
        while queue:
            current_distance, current_node = heapq.heappop(queue)
            if current_node == end:
                break
            # Skip stale heap entries that a cheaper push has already superseded.
            if current_distance > distances[current_node]:
                continue
            # Use .get() so sink nodes are not inserted into the defaultdict,
            # which would silently change the node set seen by _graph_nodes_coords.
            for neighbor, edge_weight in self.graph_neighbors.get(current_node, ()):
                new_distance = current_distance + edge_weight
                if new_distance < distances.get(neighbor, float('inf')):
                    distances[neighbor] = new_distance
                    previous_nodes[neighbor] = current_node
                    heapq.heappush(queue, (new_distance, neighbor))

        # If the end node wasn't reachable from the start
        if end not in previous_nodes:
            return [], float('inf')
        path = []
        node = end
        while node is not None:
            path.append(node)
            node = previous_nodes[node]
        path.reverse()
        return path, distances[end]