1use crate::node::{ParentNode, RTreeNode};
2use crate::point::{min_inline, Point};
3use crate::{Envelope, PointDistance, RTreeObject};
4
5use alloc::collections::BinaryHeap;
6#[cfg(not(test))]
7use alloc::{vec, vec::Vec};
8use core::mem::replace;
9use heapless::binary_heap as static_heap;
10use num_traits::Bounded;
11
12struct RTreeNodeDistanceWrapper<'a, T>
13where
14 T: PointDistance + 'a,
15{
16 node: &'a RTreeNode<T>,
17 distance: <<T::Envelope as Envelope>::Point as Point>::Scalar,
18}
19
20impl<'a, T> PartialEq for RTreeNodeDistanceWrapper<'a, T>
21where
22 T: PointDistance,
23{
24 fn eq(&self, other: &Self) -> bool {
25 self.distance == other.distance
26 }
27}
28
29impl<'a, T> PartialOrd for RTreeNodeDistanceWrapper<'a, T>
30where
31 T: PointDistance,
32{
33 fn partial_cmp(&self, other: &Self) -> Option<::core::cmp::Ordering> {
34 Some(self.cmp(other))
35 }
36}
37
38impl<'a, T> Eq for RTreeNodeDistanceWrapper<'a, T> where T: PointDistance {}
39
40impl<'a, T> Ord for RTreeNodeDistanceWrapper<'a, T>
41where
42 T: PointDistance,
43{
44 fn cmp(&self, other: &Self) -> ::core::cmp::Ordering {
45 other.distance.partial_cmp(&self.distance).unwrap()
47 }
48}
49
50impl<'a, T> NearestNeighborDistance2Iterator<'a, T>
51where
52 T: PointDistance,
53{
54 pub fn new(root: &'a ParentNode<T>, query_point: <T::Envelope as Envelope>::Point) -> Self {
55 let mut result = NearestNeighborDistance2Iterator {
56 nodes: SmallHeap::new(),
57 query_point,
58 };
59 result.extend_heap(&root.children);
60 result
61 }
62
63 fn extend_heap(&mut self, children: &'a [RTreeNode<T>]) {
64 let &mut NearestNeighborDistance2Iterator {
65 ref mut nodes,
66 ref query_point,
67 } = self;
68 nodes.extend(children.iter().map(|child| {
69 let distance = match child {
70 RTreeNode::Parent(ref data) => data.envelope.distance_2(query_point),
71 RTreeNode::Leaf(ref t) => t.distance_2(query_point),
72 };
73
74 RTreeNodeDistanceWrapper {
75 node: child,
76 distance,
77 }
78 }));
79 }
80}
81
82impl<'a, T> Iterator for NearestNeighborDistance2Iterator<'a, T>
83where
84 T: PointDistance,
85{
86 type Item = (&'a T, <<T::Envelope as Envelope>::Point as Point>::Scalar);
87
88 fn next(&mut self) -> Option<Self::Item> {
89 while let Some(current) = self.nodes.pop() {
90 match current {
91 RTreeNodeDistanceWrapper {
92 node: RTreeNode::Parent(ref data),
93 ..
94 } => {
95 self.extend_heap(&data.children);
96 }
97 RTreeNodeDistanceWrapper {
98 node: RTreeNode::Leaf(ref t),
99 distance,
100 } => {
101 return Some((t, distance));
102 }
103 }
104 }
105 None
106 }
107}
108
109pub struct NearestNeighborDistance2Iterator<'a, T>
110where
111 T: PointDistance + 'a,
112{
113 nodes: SmallHeap<RTreeNodeDistanceWrapper<'a, T>>,
114 query_point: <T::Envelope as Envelope>::Point,
115}
116
117impl<'a, T> NearestNeighborIterator<'a, T>
118where
119 T: PointDistance,
120{
121 pub fn new(root: &'a ParentNode<T>, query_point: <T::Envelope as Envelope>::Point) -> Self {
122 NearestNeighborIterator {
123 iter: NearestNeighborDistance2Iterator::new(root, query_point),
124 }
125 }
126}
127
128impl<'a, T> Iterator for NearestNeighborIterator<'a, T>
129where
130 T: PointDistance,
131{
132 type Item = &'a T;
133
134 fn next(&mut self) -> Option<Self::Item> {
135 self.iter.next().map(|(t, _distance)| t)
136 }
137}
138
139pub struct NearestNeighborIterator<'a, T>
140where
141 T: PointDistance + 'a,
142{
143 iter: NearestNeighborDistance2Iterator<'a, T>,
144}
145
146enum SmallHeap<T: Ord> {
147 Stack(static_heap::BinaryHeap<T, static_heap::Max, 32>),
148 Heap(BinaryHeap<T>),
149}
150
151impl<T: Ord> SmallHeap<T> {
152 pub fn new() -> Self {
153 Self::Stack(static_heap::BinaryHeap::new())
154 }
155
156 pub fn pop(&mut self) -> Option<T> {
157 match self {
158 SmallHeap::Stack(heap) => heap.pop(),
159 SmallHeap::Heap(heap) => heap.pop(),
160 }
161 }
162
163 pub fn push(&mut self, item: T) {
164 match self {
165 SmallHeap::Stack(heap) => {
166 if let Err(item) = heap.push(item) {
167 let capacity = heap.len() + 1;
168 let new_heap = self.spill(capacity);
169 new_heap.push(item);
170 }
171 }
172 SmallHeap::Heap(heap) => heap.push(item),
173 }
174 }
175
176 pub fn extend<I>(&mut self, iter: I)
177 where
178 I: ExactSizeIterator<Item = T>,
179 {
180 match self {
181 SmallHeap::Stack(heap) => {
182 if heap.capacity() >= heap.len() + iter.len() {
183 for item in iter {
184 if heap.push(item).is_err() {
185 unreachable!();
186 }
187 }
188 } else {
189 let capacity = heap.len() + iter.len();
190 let new_heap = self.spill(capacity);
191 new_heap.extend(iter);
192 }
193 }
194 SmallHeap::Heap(heap) => heap.extend(iter),
195 }
196 }
197
198 #[cold]
199 fn spill(&mut self, capacity: usize) -> &mut BinaryHeap<T> {
200 let new_heap = BinaryHeap::with_capacity(capacity);
201 let old_heap = replace(self, SmallHeap::Heap(new_heap));
202
203 let new_heap = match self {
204 SmallHeap::Heap(new_heap) => new_heap,
205 SmallHeap::Stack(_) => unreachable!(),
206 };
207 let old_heap = match old_heap {
208 SmallHeap::Stack(old_heap) => old_heap,
209 SmallHeap::Heap(_) => unreachable!(),
210 };
211
212 new_heap.extend(old_heap.into_vec());
213
214 new_heap
215 }
216}
217
218pub fn nearest_neighbor<T>(
219 node: &ParentNode<T>,
220 query_point: <T::Envelope as Envelope>::Point,
221) -> Option<&T>
222where
223 T: PointDistance,
224{
225 fn extend_heap<'a, T>(
226 nodes: &mut SmallHeap<RTreeNodeDistanceWrapper<'a, T>>,
227 node: &'a ParentNode<T>,
228 query_point: <T::Envelope as Envelope>::Point,
229 min_max_distance: &mut <<T::Envelope as Envelope>::Point as Point>::Scalar,
230 ) where
231 T: PointDistance + 'a,
232 {
233 for child in &node.children {
234 let distance_if_less_or_equal = match child {
235 RTreeNode::Parent(ref data) => {
236 let distance = data.envelope.distance_2(&query_point);
237 if distance <= *min_max_distance {
238 Some(distance)
239 } else {
240 None
241 }
242 }
243 RTreeNode::Leaf(ref t) => {
244 t.distance_2_if_less_or_equal(&query_point, *min_max_distance)
245 }
246 };
247 if let Some(distance) = distance_if_less_or_equal {
248 *min_max_distance = min_inline(
249 *min_max_distance,
250 child.envelope().min_max_dist_2(&query_point),
251 );
252 nodes.push(RTreeNodeDistanceWrapper {
253 node: child,
254 distance,
255 });
256 }
257 }
258 }
259
260 let mut smallest_min_max: <<T::Envelope as Envelope>::Point as Point>::Scalar =
262 Bounded::max_value();
263 let mut nodes = SmallHeap::new();
264 extend_heap(&mut nodes, node, query_point.clone(), &mut smallest_min_max);
265 while let Some(current) = nodes.pop() {
266 match current {
267 RTreeNodeDistanceWrapper {
268 node: RTreeNode::Parent(ref data),
269 ..
270 } => {
271 extend_heap(&mut nodes, data, query_point.clone(), &mut smallest_min_max);
272 }
273 RTreeNodeDistanceWrapper {
274 node: RTreeNode::Leaf(ref t),
275 ..
276 } => {
277 return Some(t);
278 }
279 }
280 }
281 None
282}
283
284pub fn nearest_neighbors<T>(
285 node: &ParentNode<T>,
286 query_point: <T::Envelope as Envelope>::Point,
287) -> Vec<&T>
288where
289 T: PointDistance,
290{
291 let mut nearest_neighbors = NearestNeighborDistance2Iterator::new(node, query_point.clone());
292
293 let (first, first_distance_2) = match nearest_neighbors.next() {
294 Some(item) => item,
295 None => return Vec::new(),
297 };
298
299 let mut result = vec![first];
301
302 result.extend(
306 nearest_neighbors
307 .take_while(|(_, next_distance_2)| next_distance_2 == &first_distance_2)
308 .map(|(next, _)| next),
309 );
310
311 result
312}
313
#[cfg(test)]
mod test {
    use crate::object::PointDistance;
    use crate::rtree::RTree;
    use crate::test_utilities::*;

    #[test]
    fn test_nearest_neighbor_empty() {
        let tree: RTree<[f32; 2]> = RTree::new();
        assert!(tree.nearest_neighbor(&[0.0, 213.0]).is_none());
    }

    #[test]
    fn test_nearest_neighbor() {
        let points = create_random_points(1000, SEED_1);
        let tree = RTree::bulk_load(points.clone());

        let sample_points = create_random_points(100, SEED_2);
        for sample_point in &sample_points {
            let mut nearest = None;
            let mut closest_dist = f64::INFINITY;
            for point in &points {
                let delta = [point[0] - sample_point[0], point[1] - sample_point[1]];
                let new_dist = delta[0] * delta[0] + delta[1] * delta[1];
                if new_dist < closest_dist {
                    closest_dist = new_dist;
                    nearest = Some(point);
                }
            }
            assert_eq!(nearest, tree.nearest_neighbor(sample_point));
        }
    }

    #[test]
    fn test_nearest_neighbors_empty() {
        let tree: RTree<[f32; 2]> = RTree::new();
        assert!(tree.nearest_neighbors(&[0.0, 213.0]).is_empty());
    }

    #[test]
    fn test_nearest_neighbors() {
        let points = create_random_points(1000, SEED_1);
        let tree = RTree::bulk_load(points.clone());

        let sample_points = create_random_points(50, SEED_2);
        for sample_point in &sample_points {
            let nearest_neighbors = tree.nearest_neighbors(sample_point);
            // The tree is non-empty, so there is at least one nearest neighbor.
            assert!(!nearest_neighbors.is_empty());

            // Brute-force ground truth: the minimum distance and how many points attain it.
            let expected = points
                .iter()
                .map(|p| p.distance_2(sample_point))
                .fold(f64::INFINITY, f64::min);
            let expected_count = points
                .iter()
                .filter(|p| p.distance_2(sample_point) == expected)
                .count();

            for nn in &nearest_neighbors {
                assert_eq!(nn.distance_2(sample_point), expected);
            }
            assert_eq!(nearest_neighbors.len(), expected_count);
        }
    }

    #[test]
    fn test_nearest_neighbor_iterator() {
        let mut points = create_random_points(1000, SEED_1);
        let tree = RTree::bulk_load(points.clone());

        let sample_points = create_random_points(50, SEED_2);
        for sample_point in &sample_points {
            points.sort_by(|r, l| {
                r.distance_2(sample_point)
                    .partial_cmp(&l.distance_2(sample_point))
                    .unwrap()
            });
            let collected: Vec<_> = tree.nearest_neighbor_iter(sample_point).cloned().collect();
            assert_eq!(points, collected);
        }
    }

    #[test]
    fn test_nearest_neighbor_iterator_with_distance_2() {
        let points = create_random_points(1000, SEED_2);
        let tree = RTree::bulk_load(points);

        let sample_points = create_random_points(50, SEED_1);
        for sample_point in &sample_points {
            let mut last_distance = 0.0;
            for (point, distance) in tree.nearest_neighbor_iter_with_distance_2(sample_point) {
                assert_eq!(point.distance_2(sample_point), distance);
                // The iterator only guarantees non-decreasing order; equal
                // distances (ties) are legitimate and must not fail the test.
                assert!(last_distance <= distance);
                last_distance = distance;
            }
        }
    }
}
405}