rstar/algorithm/
nearest_neighbor.rs

1use crate::node::{ParentNode, RTreeNode};
2use crate::point::{min_inline, Point};
3use crate::{Envelope, PointDistance, RTreeObject};
4
5use alloc::collections::BinaryHeap;
6#[cfg(not(test))]
7use alloc::{vec, vec::Vec};
8use core::mem::replace;
9use heapless::binary_heap as static_heap;
10use num_traits::Bounded;
11
12struct RTreeNodeDistanceWrapper<'a, T>
13where
14    T: PointDistance + 'a,
15{
16    node: &'a RTreeNode<T>,
17    distance: <<T::Envelope as Envelope>::Point as Point>::Scalar,
18}
19
20impl<'a, T> PartialEq for RTreeNodeDistanceWrapper<'a, T>
21where
22    T: PointDistance,
23{
24    fn eq(&self, other: &Self) -> bool {
25        self.distance == other.distance
26    }
27}
28
29impl<'a, T> PartialOrd for RTreeNodeDistanceWrapper<'a, T>
30where
31    T: PointDistance,
32{
33    fn partial_cmp(&self, other: &Self) -> Option<::core::cmp::Ordering> {
34        Some(self.cmp(other))
35    }
36}
37
38impl<'a, T> Eq for RTreeNodeDistanceWrapper<'a, T> where T: PointDistance {}
39
40impl<'a, T> Ord for RTreeNodeDistanceWrapper<'a, T>
41where
42    T: PointDistance,
43{
44    fn cmp(&self, other: &Self) -> ::core::cmp::Ordering {
45        // Inverse comparison creates a min heap
46        other.distance.partial_cmp(&self.distance).unwrap()
47    }
48}
49
50impl<'a, T> NearestNeighborDistance2Iterator<'a, T>
51where
52    T: PointDistance,
53{
54    pub fn new(root: &'a ParentNode<T>, query_point: <T::Envelope as Envelope>::Point) -> Self {
55        let mut result = NearestNeighborDistance2Iterator {
56            nodes: SmallHeap::new(),
57            query_point,
58        };
59        result.extend_heap(&root.children);
60        result
61    }
62
63    fn extend_heap(&mut self, children: &'a [RTreeNode<T>]) {
64        let &mut NearestNeighborDistance2Iterator {
65            ref mut nodes,
66            ref query_point,
67        } = self;
68        nodes.extend(children.iter().map(|child| {
69            let distance = match child {
70                RTreeNode::Parent(ref data) => data.envelope.distance_2(query_point),
71                RTreeNode::Leaf(ref t) => t.distance_2(query_point),
72            };
73
74            RTreeNodeDistanceWrapper {
75                node: child,
76                distance,
77            }
78        }));
79    }
80}
81
82impl<'a, T> Iterator for NearestNeighborDistance2Iterator<'a, T>
83where
84    T: PointDistance,
85{
86    type Item = (&'a T, <<T::Envelope as Envelope>::Point as Point>::Scalar);
87
88    fn next(&mut self) -> Option<Self::Item> {
89        while let Some(current) = self.nodes.pop() {
90            match current {
91                RTreeNodeDistanceWrapper {
92                    node: RTreeNode::Parent(ref data),
93                    ..
94                } => {
95                    self.extend_heap(&data.children);
96                }
97                RTreeNodeDistanceWrapper {
98                    node: RTreeNode::Leaf(ref t),
99                    distance,
100                } => {
101                    return Some((t, distance));
102                }
103            }
104        }
105        None
106    }
107}
108
109pub struct NearestNeighborDistance2Iterator<'a, T>
110where
111    T: PointDistance + 'a,
112{
113    nodes: SmallHeap<RTreeNodeDistanceWrapper<'a, T>>,
114    query_point: <T::Envelope as Envelope>::Point,
115}
116
117impl<'a, T> NearestNeighborIterator<'a, T>
118where
119    T: PointDistance,
120{
121    pub fn new(root: &'a ParentNode<T>, query_point: <T::Envelope as Envelope>::Point) -> Self {
122        NearestNeighborIterator {
123            iter: NearestNeighborDistance2Iterator::new(root, query_point),
124        }
125    }
126}
127
128impl<'a, T> Iterator for NearestNeighborIterator<'a, T>
129where
130    T: PointDistance,
131{
132    type Item = &'a T;
133
134    fn next(&mut self) -> Option<Self::Item> {
135        self.iter.next().map(|(t, _distance)| t)
136    }
137}
138
139pub struct NearestNeighborIterator<'a, T>
140where
141    T: PointDistance + 'a,
142{
143    iter: NearestNeighborDistance2Iterator<'a, T>,
144}
145
146enum SmallHeap<T: Ord> {
147    Stack(static_heap::BinaryHeap<T, static_heap::Max, 32>),
148    Heap(BinaryHeap<T>),
149}
150
151impl<T: Ord> SmallHeap<T> {
152    pub fn new() -> Self {
153        Self::Stack(static_heap::BinaryHeap::new())
154    }
155
156    pub fn pop(&mut self) -> Option<T> {
157        match self {
158            SmallHeap::Stack(heap) => heap.pop(),
159            SmallHeap::Heap(heap) => heap.pop(),
160        }
161    }
162
163    pub fn push(&mut self, item: T) {
164        match self {
165            SmallHeap::Stack(heap) => {
166                if let Err(item) = heap.push(item) {
167                    let capacity = heap.len() + 1;
168                    let new_heap = self.spill(capacity);
169                    new_heap.push(item);
170                }
171            }
172            SmallHeap::Heap(heap) => heap.push(item),
173        }
174    }
175
176    pub fn extend<I>(&mut self, iter: I)
177    where
178        I: ExactSizeIterator<Item = T>,
179    {
180        match self {
181            SmallHeap::Stack(heap) => {
182                if heap.capacity() >= heap.len() + iter.len() {
183                    for item in iter {
184                        if heap.push(item).is_err() {
185                            unreachable!();
186                        }
187                    }
188                } else {
189                    let capacity = heap.len() + iter.len();
190                    let new_heap = self.spill(capacity);
191                    new_heap.extend(iter);
192                }
193            }
194            SmallHeap::Heap(heap) => heap.extend(iter),
195        }
196    }
197
198    #[cold]
199    fn spill(&mut self, capacity: usize) -> &mut BinaryHeap<T> {
200        let new_heap = BinaryHeap::with_capacity(capacity);
201        let old_heap = replace(self, SmallHeap::Heap(new_heap));
202
203        let new_heap = match self {
204            SmallHeap::Heap(new_heap) => new_heap,
205            SmallHeap::Stack(_) => unreachable!(),
206        };
207        let old_heap = match old_heap {
208            SmallHeap::Stack(old_heap) => old_heap,
209            SmallHeap::Heap(_) => unreachable!(),
210        };
211
212        new_heap.extend(old_heap.into_vec());
213
214        new_heap
215    }
216}
217
218pub fn nearest_neighbor<T>(
219    node: &ParentNode<T>,
220    query_point: <T::Envelope as Envelope>::Point,
221) -> Option<&T>
222where
223    T: PointDistance,
224{
225    fn extend_heap<'a, T>(
226        nodes: &mut SmallHeap<RTreeNodeDistanceWrapper<'a, T>>,
227        node: &'a ParentNode<T>,
228        query_point: <T::Envelope as Envelope>::Point,
229        min_max_distance: &mut <<T::Envelope as Envelope>::Point as Point>::Scalar,
230    ) where
231        T: PointDistance + 'a,
232    {
233        for child in &node.children {
234            let distance_if_less_or_equal = match child {
235                RTreeNode::Parent(ref data) => {
236                    let distance = data.envelope.distance_2(&query_point);
237                    if distance <= *min_max_distance {
238                        Some(distance)
239                    } else {
240                        None
241                    }
242                }
243                RTreeNode::Leaf(ref t) => {
244                    t.distance_2_if_less_or_equal(&query_point, *min_max_distance)
245                }
246            };
247            if let Some(distance) = distance_if_less_or_equal {
248                *min_max_distance = min_inline(
249                    *min_max_distance,
250                    child.envelope().min_max_dist_2(&query_point),
251                );
252                nodes.push(RTreeNodeDistanceWrapper {
253                    node: child,
254                    distance,
255                });
256            }
257        }
258    }
259
260    // Calculate smallest minmax-distance
261    let mut smallest_min_max: <<T::Envelope as Envelope>::Point as Point>::Scalar =
262        Bounded::max_value();
263    let mut nodes = SmallHeap::new();
264    extend_heap(&mut nodes, node, query_point.clone(), &mut smallest_min_max);
265    while let Some(current) = nodes.pop() {
266        match current {
267            RTreeNodeDistanceWrapper {
268                node: RTreeNode::Parent(ref data),
269                ..
270            } => {
271                extend_heap(&mut nodes, data, query_point.clone(), &mut smallest_min_max);
272            }
273            RTreeNodeDistanceWrapper {
274                node: RTreeNode::Leaf(ref t),
275                ..
276            } => {
277                return Some(t);
278            }
279        }
280    }
281    None
282}
283
284pub fn nearest_neighbors<T>(
285    node: &ParentNode<T>,
286    query_point: <T::Envelope as Envelope>::Point,
287) -> Vec<&T>
288where
289    T: PointDistance,
290{
291    let mut nearest_neighbors = NearestNeighborDistance2Iterator::new(node, query_point.clone());
292
293    let (first, first_distance_2) = match nearest_neighbors.next() {
294        Some(item) => item,
295        // If we have an empty tree, just return an empty vector.
296        None => return Vec::new(),
297    };
298
299    // The result will at least contain the first nearest neighbor.
300    let mut result = vec![first];
301
302    // Use the distance to the first nearest neighbor
303    // to filter out the rest of the nearest neighbors
304    // that are farther than this first neighbor.
305    result.extend(
306        nearest_neighbors
307            .take_while(|(_, next_distance_2)| next_distance_2 == &first_distance_2)
308            .map(|(next, _)| next),
309    );
310
311    result
312}
313
#[cfg(test)]
mod test {
    use crate::object::PointDistance;
    use crate::rtree::RTree;
    use crate::test_utilities::*;

    #[test]
    fn test_nearest_neighbor_empty() {
        let tree: RTree<[f32; 2]> = RTree::new();
        assert!(tree.nearest_neighbor(&[0.0, 213.0]).is_none());
    }

    #[test]
    fn test_nearest_neighbor() {
        let points = create_random_points(1000, SEED_1);
        let tree = RTree::bulk_load(points.clone());

        let sample_points = create_random_points(100, SEED_2);
        for sample_point in &sample_points {
            let mut nearest = None;
            let mut closest_dist = f64::INFINITY;
            for point in &points {
                let delta = [point[0] - sample_point[0], point[1] - sample_point[1]];
                let new_dist = delta[0] * delta[0] + delta[1] * delta[1];
                if new_dist < closest_dist {
                    closest_dist = new_dist;
                    nearest = Some(point);
                }
            }
            assert_eq!(nearest, tree.nearest_neighbor(sample_point));
        }
    }

    #[test]
    fn test_nearest_neighbors_empty() {
        let tree: RTree<[f32; 2]> = RTree::new();
        assert!(tree.nearest_neighbors(&[0.0, 213.0]).is_empty());
    }

    #[test]
    fn test_nearest_neighbors() {
        let points = create_random_points(1000, SEED_1);
        let tree = RTree::bulk_load(points.clone());

        let sample_points = create_random_points(50, SEED_2);
        for sample_point in &sample_points {
            // Brute-force reference: the minimal squared distance and how many points reach it.
            let expected_distance = points
                .iter()
                .map(|point| sample_point.distance_2(point))
                .fold(f64::INFINITY, f64::min);
            let expected_count = points
                .iter()
                .filter(|point| sample_point.distance_2(point) == expected_distance)
                .count();

            let nearest_neighbors = tree.nearest_neighbors(sample_point);
            // A non-empty tree always has at least one nearest neighbor.
            assert!(!nearest_neighbors.is_empty());
            assert_eq!(nearest_neighbors.len(), expected_count);
            for nn in &nearest_neighbors {
                assert_eq!(sample_point.distance_2(nn), expected_distance);
            }
        }
    }

    #[test]
    fn test_nearest_neighbor_iterator() {
        let mut points = create_random_points(1000, SEED_1);
        let tree = RTree::bulk_load(points.clone());

        let sample_points = create_random_points(50, SEED_2);
        for sample_point in &sample_points {
            points.sort_by(|r, l| {
                r.distance_2(sample_point)
                    .partial_cmp(&l.distance_2(sample_point))
                    .unwrap()
            });
            let collected: Vec<_> = tree.nearest_neighbor_iter(sample_point).cloned().collect();
            assert_eq!(points, collected);
        }
    }

    #[test]
    fn test_nearest_neighbor_iterator_with_distance_2() {
        let points = create_random_points(1000, SEED_2);
        let tree = RTree::bulk_load(points);

        let sample_points = create_random_points(50, SEED_1);
        for sample_point in &sample_points {
            let mut last_distance = 0.0;
            for (point, distance) in tree.nearest_neighbor_iter_with_distance_2(sample_point) {
                assert_eq!(point.distance_2(sample_point), distance);
                // Distances are only guaranteed non-decreasing: equidistant points (or a
                // point coinciding with the sample) legitimately repeat a distance.
                assert!(last_distance <= distance);
                last_distance = distance;
            }
        }
    }
}
403        }
404    }
405}